Пример #1
0
def test_conditional_mvn_diag_whiten(precomputed_cov_mats):
    """Whitened Gaussian conditional of a full-covariance input, diagonal output.

    The conditioned-on distribution is a batch of two 5-dimensional
    ``MultivariateNormal``s, each given by a mean vector and a lower-triangular
    ``scale_tril`` factor. ``cov_aa`` from the fixture is passed through
    ``torch.diag`` (presumably extracting the diagonal of a square matrix --
    TODO confirm the fixture's shape). ``conditional_gaussian`` is then called
    with ``return_full_cov_flag=False`` and ``whiten=True``. The returned mean
    and variance are compared against hard-coded reference arrays of shape
    (2, 8).
    """
    f_loc = torch.tensor(
        [
            [-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
            [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    # One lower-triangular scale factor per batch entry (5x5 each).
    f_scale_tril = torch.tensor(
        [
            [
                [1.2913755, 0.0, 0.0, 0.0, 0.0],
                [-0.21657062, 1.38901088, 0.0, 0.0, 0.0],
                [1.05706677, -1.69033564, 0.81725719, 0.0, 0.0],
                [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.0],
                [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327],
            ],
            [
                [2.29744629, 0.0, 0.0, 0.0, 0.0],
                [-0.37817459, 2.0557189, 0.0, 0.0, 0.0],
                [-1.14057908, -1.58899131, 2.05868389, 0.0, 0.0],
                [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.0],
                [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534],
            ],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    f_distribution = distributions.MultivariateNormal(
        f_loc, scale_tril=f_scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # Only the diagonal of cov_aa is handed over because a full output
    # covariance is not requested (return_full_cov_flag=False).
    cov_aa_diag = torch.diag(cov_aa)
    pred_mean, pred_var = conditionals.conditional_gaussian(
        cov_aa_diag,
        cov_ba,
        cov_bb,
        f_distribution,
        return_full_cov_flag=False,
        whiten=True,
    )

    expected_mean = np.array(
        [
            [0.9983176021, -3.9736127037, 0.6906933359, 2.0855503170,
             0.0715756755, -1.2503840376, -1.8458709690, 0.8542689689],
            [-1.1424803951, -1.3786447703, -0.9203628273, -2.2190894546,
             -0.1230040081, -1.4342874653, -1.0501022539, -1.1778143477],
        ]
    )
    expected_var = np.array(
        [
            [14.2586345250, 14.0537436795, 9.2947300720, 14.1982420873,
             7.1713542879, 19.1857305446, 18.6136875294, 10.9176187167],
            [14.4353048611, 25.0036833948, 11.4511969440, 27.2123889213,
             7.2121129353, 27.2155613383, 22.9932268895, 14.4766175341],
        ]
    )

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_var.numpy(), expected_var)
Пример #2
0
def test_conditional_mvn_diag(precomputed_cov_mats):
    """Non-whitened Gaussian conditional of a full-covariance input, diagonal output.

    Same inputs as the whitened variant: a batch of two 5-dimensional
    ``MultivariateNormal``s (mean + lower-triangular ``scale_tril``) and the
    fixture's ``cov_aa`` passed through ``torch.diag`` (presumably its
    diagonal -- TODO confirm the fixture's shape). Here ``whiten=False`` and
    ``return_full_cov_flag=False``. The results are compared against
    hard-coded (2, 8) reference arrays.
    """
    f_loc = torch.tensor(
        [
            [-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
            [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    # One lower-triangular scale factor per batch entry (5x5 each).
    f_scale_tril = torch.tensor(
        [
            [
                [1.2913755, 0.0, 0.0, 0.0, 0.0],
                [-0.21657062, 1.38901088, 0.0, 0.0, 0.0],
                [1.05706677, -1.69033564, 0.81725719, 0.0, 0.0],
                [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.0],
                [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327],
            ],
            [
                [2.29744629, 0.0, 0.0, 0.0, 0.0],
                [-0.37817459, 2.0557189, 0.0, 0.0, 0.0],
                [-1.14057908, -1.58899131, 2.05868389, 0.0, 0.0],
                [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.0],
                [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534],
            ],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    f_distribution = distributions.MultivariateNormal(
        f_loc, scale_tril=f_scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # Diagonal-only output requested, so only the diagonal of cov_aa is used.
    cov_aa_diag = torch.diag(cov_aa)
    pred_mean, pred_var = conditionals.conditional_gaussian(
        cov_aa_diag,
        cov_ba,
        cov_bb,
        f_distribution,
        return_full_cov_flag=False,
        whiten=False,
    )

    expected_mean = np.array(
        [
            [1.0395413577, -1.8514740569, 0.3141941933, 0.8476137584,
             0.0382751075, -0.3854980598, -0.8345681935, 0.3999459519],
            [-0.0090321213, -0.7372578999, -0.4162285925, -0.8602622449,
             -0.0616100437, -0.4430488838, -0.2661613907, -0.5424388696],
        ]
    )
    expected_var = np.array(
        [
            [5.5969765530, 6.3781133317, 6.9028852306, 3.7262482077,
             7.1214550454, 7.0502994187, 7.7957730662, 6.8558710306],
            [5.6527201408, 12.4449391102, 7.3493964930, 5.7165059115,
             7.1316647669, 10.6323246135, 10.7472708960, 7.6216942496],
        ]
    )

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_var.numpy(), expected_var)
Пример #3
0
def test_conditional_delta_diag(precomputed_cov_mats):
    """Non-whitened conditional of a ``Delta`` input, diagonal output.

    The conditioned-on values are wrapped in ``custom_distributions.Delta``
    (presumably a point mass with no variance -- TODO confirm). The fixture's
    ``cov_aa`` is passed through ``torch.diag``, and ``conditional_gaussian``
    is called with ``whiten=False`` and ``return_full_cov_flag=False``. The
    reference variance is the same row for both batch entries.
    """
    delta_loc = torch.tensor(
        [
            [1.23, -0.23, 7.45, 2.45, -3.45],
            [0.85, 6.4, -8.21, -0.45, -0.7],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_dist = custom_distributions.Delta(delta_loc)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cov_aa_diag = torch.diag(cov_aa)
    pred_mean, pred_var = conditionals.conditional_gaussian(
        cov_aa_diag,
        cov_ba,
        cov_bb,
        delta_dist,
        return_full_cov_flag=False,
        whiten=False,
    )

    expected_mean = np.array(
        [
            [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
             -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
            [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
             0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
        ]
    )
    # Both batch entries share the same expected variance row.
    shared_var_row = [
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ]
    expected_var = np.array([shared_var_row, shared_var_row])

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_var.numpy(), expected_var)
Пример #4
0
def test_conditional_delta_diag_whiten(precomputed_cov_mats):
    """Whitened conditional of a ``Delta`` input, diagonal output.

    The inputs are the same as in the non-whitened ``Delta`` diagonal test,
    but ``whiten=True``. The expected means differ from that case. The
    expected variances are the same hard-coded row, repeated for both batch
    entries.
    """
    delta_loc = torch.tensor(
        [
            [1.23, -0.23, 7.45, 2.45, -3.45],
            [0.85, 6.4, -8.21, -0.45, -0.7],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_dist = custom_distributions.Delta(delta_loc)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cov_aa_diag = torch.diag(cov_aa)
    pred_mean, pred_var = conditionals.conditional_gaussian(
        cov_aa_diag,
        cov_ba,
        cov_bb,
        delta_dist,
        return_full_cov_flag=False,
        whiten=True,
    )

    expected_mean = np.array(
        [
            [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
             -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
            [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
             1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
        ]
    )
    # Both batch entries share the same expected variance row.
    shared_var_row = [
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ]
    expected_var = np.array([shared_var_row, shared_var_row])

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_var.numpy(), expected_var)
Пример #5
0
def test_conditional_mvn_full_cov_whiten(precomputed_cov_mats):
    """Whitened Gaussian conditional of a full-covariance input, full output.

    The conditioned-on distribution is a batch of two 5-dimensional
    ``MultivariateNormal``s (mean + lower-triangular ``scale_tril``). The
    fixture's ``cov_aa`` is passed unchanged, with ``return_full_cov_flag=True``
    and ``whiten=True``. The expected mean has shape (2, 8) and the expected
    covariance has shape (2, 8, 8). Both are hard-coded reference values.
    """
    f_loc = torch.tensor(
        [
            [-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
            [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    # One lower-triangular scale factor per batch entry (5x5 each).
    f_scale_tril = torch.tensor(
        [
            [
                [1.2913755, 0.0, 0.0, 0.0, 0.0],
                [-0.21657062, 1.38901088, 0.0, 0.0, 0.0],
                [1.05706677, -1.69033564, 0.81725719, 0.0, 0.0],
                [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.0],
                [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327],
            ],
            [
                [2.29744629, 0.0, 0.0, 0.0, 0.0],
                [-0.37817459, 2.0557189, 0.0, 0.0, 0.0],
                [-1.14057908, -1.58899131, 2.05868389, 0.0, 0.0],
                [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.0],
                [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534],
            ],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    f_distribution = distributions.MultivariateNormal(
        f_loc, scale_tril=f_scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    pred_mean, pred_cov = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        f_distribution,
        return_full_cov_flag=True,
        whiten=True,
    )

    expected_mean = np.array(
        [
            [0.9983176021, -3.9736127037, 0.6906933359, 2.0855503170,
             0.0715756755, -1.2503840376, -1.8458709690, 0.8542689689],
            [-1.1424803951, -1.3786447703, -0.9203628273, -2.2190894546,
             -0.1230040081, -1.4342874653, -1.0501022539, -1.1778143477],
        ]
    )
    cov_batch_0 = [
        [14.2586345250, 9.1822888138, -2.6004046698, -4.0531016056,
         -0.3628278468, 12.8795825333, 11.8793483846, -3.3109476844],
        [9.1822888138, 14.0537436795, -3.0078651759, -5.9977582226,
         -0.4290556306, 12.5841457955, 11.9920397559, -3.9239952650],
        [-2.6004046698, -3.0078651759, 9.2947300720, 7.1303199998,
         1.1779780651, -3.5088886346, -4.5304743531, 6.4427563661],
        [-4.0531016056, -5.9977582226, 7.1303199998, 14.1982420873,
         0.9852830967, -6.9005426267, -9.0417278718, 8.8555650766],
        [-0.3628278468, -0.4290556306, 1.1779780651, 0.9852830967,
         7.1713542879, -0.5043240844, -0.6476378715, 2.8130058040],
        [12.8795825333, 12.5841457955, -3.5088886346, -6.9005426267,
         -0.5043240844, 19.1857305446, 15.9460640932, -4.7058678889],
        [11.8793483846, 11.9920397559, -4.5304743531, -9.0417278718,
         -0.6476378715, 15.9460640932, 18.6136875294, -5.9171923298],
        [-3.3109476844, -3.9239952650, 6.4427563661, 8.8555650766,
         2.8130058040, -4.7058678889, -5.9171923298, 10.9176187167],
    ]
    cov_batch_1 = [
        [14.4353048611, 7.6440115383, 1.0779014652, 5.6157728594,
         0.0847217439, 14.3465260481, 11.8000939369, 1.3130612969],
        [7.6440115383, 25.0036833948, -1.0122776158, -2.6434071073,
         -0.0927873657, 19.4050847421, 17.3934295022, -1.2793426390],
        [1.0779014652, -1.0122776158, 11.4511969440, 12.2886570287,
         1.4698466940, -0.2455049374, -2.5919774292, 9.2114657874],
        [5.6157728594, -2.6434071073, 12.2886570287, 27.2123889213,
         1.6542042466, 1.3498022098, -4.1299704160, 15.4245118703],
        [0.0847217439, -0.0927873657, 1.4698466940, 1.6542042466,
         7.2121129353, -0.0816546991, -0.3961193855, 3.1899994546],
        [14.3465260481, 19.4050847421, -0.2455049374, 1.3498022098,
         -0.0816546991, 27.2155613383, 21.7753020846, -0.5858162823],
        [11.8000939369, 17.3934295022, -2.5919774292, -4.1299704160,
         -0.3961193855, 21.7753020846, 22.9932268895, -3.4791494930],
        [1.3130612969, -1.2793426390, 9.2114657874, 15.4245118703,
         3.1899994546, -0.5858162823, -3.4791494930, 14.4766175341],
    ]
    expected_var = np.array([cov_batch_0, cov_batch_1])

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_cov.numpy(), expected_var)
Пример #6
0
def test_conditional_mvn_full_cov(precomputed_cov_mats):
    """Non-whitened Gaussian conditional of a full-covariance input, full output.

    Same inputs as the whitened full-covariance test, but ``whiten=False``.
    The fixture's ``cov_aa`` is passed unchanged with
    ``return_full_cov_flag=True``. The results are compared against a
    hard-coded (2, 8) mean and a (2, 8, 8) covariance.
    """
    f_loc = torch.tensor(
        [
            [-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
            [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    # One lower-triangular scale factor per batch entry (5x5 each).
    f_scale_tril = torch.tensor(
        [
            [
                [1.2913755, 0.0, 0.0, 0.0, 0.0],
                [-0.21657062, 1.38901088, 0.0, 0.0, 0.0],
                [1.05706677, -1.69033564, 0.81725719, 0.0, 0.0],
                [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.0],
                [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327],
            ],
            [
                [2.29744629, 0.0, 0.0, 0.0, 0.0],
                [-0.37817459, 2.0557189, 0.0, 0.0, 0.0],
                [-1.14057908, -1.58899131, 2.05868389, 0.0, 0.0],
                [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.0],
                [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534],
            ],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    f_distribution = distributions.MultivariateNormal(
        f_loc, scale_tril=f_scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    pred_mean, pred_cov = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        f_distribution,
        return_full_cov_flag=True,
        whiten=False,
    )

    expected_mean = np.array(
        [
            [1.0395413577, -1.8514740569, 0.3141941933, 0.8476137584,
             0.0382751075, -0.3854980598, -0.8345681935, 0.3999459519],
            [-0.0090321213, -0.7372578999, -0.4162285925, -0.8602622449,
             -0.0616100437, -0.4430488838, -0.2661613907, -0.5424388696],
        ]
    )
    cov_batch_0 = [
        [5.5969765530, 0.8987337154, -1.0047342826, -0.9199924557,
         -0.1427258369, 2.6960902126, 2.6752037660, -1.2360837610],
        [0.8987337154, 6.3781133317, -0.6323356576, -1.1976265383,
         -0.0958950157, 2.9462344884, 2.9279693862, -0.8347931889],
        [-1.0047342826, -0.6323356576, 6.9028852306, 2.1426581153,
         0.8332082246, -0.9641806581, -1.1681775127, 3.3260099106],
        [-0.9199924557, -1.1976265383, 2.1426581153, 3.7262482077,
         0.2696427010, -1.6835020270, -2.2145982904, 2.3613305103],
        [-0.1427258369, -0.0958950157, 0.8332082246, 0.2696427010,
         7.1214550454, -0.1550713938, -0.1714103286, 2.3634627316],
        [2.6960902126, 2.9462344884, -0.9641806581, -1.6835020270,
         -0.1550713938, 7.0502994187, 4.8934529747, -1.4055364028],
        [2.6752037660, 2.9279693862, -1.1681775127, -2.2145982904,
         -0.1714103286, 4.8934529747, 7.7957730662, -1.5401611640],
        [-1.2360837610, -0.8347931889, 3.3260099106, 2.3613305103,
         2.3634627316, -1.4055364028, -1.5401611640, 6.8558710306],
    ]
    cov_batch_1 = [
        [5.6527201408, 0.2022671057, -0.6651704627, 0.0611261029,
         -0.1054391073, 2.6737827700, 2.4564956142, -0.8162208570],
        [0.2022671057, 12.4449391102, -0.2266653547, -0.4712621969,
         -0.0265284815, 7.1430618936, 6.7646918858, -0.3103502309],
        [-0.6651704627, -0.2266653547, 7.3493964930, 3.0578181632,
         0.9001054998, -0.5528150962, -1.0520779708, 3.9105112597],
        [0.0611261029, -0.4712621969, 3.0578181632, 5.7165059115,
         0.4021351181, -0.6077735410, -1.7997565173, 3.5505399050],
        [-0.1054391073, -0.0265284815, 0.9001054998, 0.4021351181,
         7.1316647669, -0.1003523807, -0.1585082426, 2.4513695028],
        [2.6737827700, 7.1430618936, -0.5528150962, -0.6077735410,
         -0.1003523807, 10.6323246135, 8.0648547968, -0.8970605421],
        [2.4564956142, 6.7646918858, -1.0520779708, -1.7997565173,
         -0.1585082426, 8.0648547968, 10.7472708960, -1.4129403011],
        [-0.8162208570, -0.3103502309, 3.9105112597, 3.5505399050,
         2.4513695028, -0.8970605421, -1.4129403011, 7.6216942496],
    ]
    expected_var = np.array([cov_batch_0, cov_batch_1])

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_cov.numpy(), expected_var)
Пример #7
0
def test_conditional_delta_full_cov_whiten(precomputed_cov_mats):
    """Whitened conditional of a ``Delta`` input, full output covariance.

    The conditioned-on values are wrapped in ``custom_distributions.Delta``
    (presumably a zero-variance point mass -- TODO confirm). The fixture's
    ``cov_aa`` is passed unchanged with ``return_full_cov_flag=True`` and
    ``whiten=True``. The expected (2, 8, 8) covariance uses the same 8x8
    matrix for both batch entries.
    """
    delta_loc = torch.tensor(
        [
            [1.23, -0.23, 7.45, 2.45, -3.45],
            [0.85, 6.4, -8.21, -0.45, -0.7],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_dist = custom_distributions.Delta(delta_loc)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    pred_mean, pred_cov = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        delta_dist,
        return_full_cov_flag=True,
        whiten=True,
    )

    expected_mean = np.array(
        [
            [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
             -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
            [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
             1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
        ]
    )
    # Both batch entries share this expected 8x8 covariance.
    shared_cov = np.array(
        [
            [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
             -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
             6.8652224581e-01, -5.9944796542e-01],
            [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
             3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
             -4.0177948193e-01, 5.5950707268e-02],
            [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
             1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
             -6.0924748295e-03, 2.4712302817e+00],
            [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
             1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
             -1.4162076285e-01, 8.5758006162e-01],
            [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
             9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
             8.4467531841e-03, 2.2271565406e+00],
            [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
             -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
             1.3024795901e+00, -3.5594698486e-01],
            [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
             -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
             3.8891409966e+00, 2.9443812570e-03],
            [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
             8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
             2.9443812570e-03, 5.7185561873e+00],
        ]
    )
    expected_var = np.stack([shared_cov, shared_cov], axis=0)

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_cov.numpy(), expected_var)
Пример #8
0
def test_conditional_delta_full_cov(precomputed_cov_mats):
    """Non-whitened conditional of a ``Delta`` input, full output covariance.

    The inputs are the same as in the whitened ``Delta`` full-covariance test,
    but ``whiten=False``. The expected means differ from that case. The
    expected covariance is the same 8x8 matrix for both batch entries.
    """
    delta_loc = torch.tensor(
        [
            [1.23, -0.23, 7.45, 2.45, -3.45],
            [0.85, 6.4, -8.21, -0.45, -0.7],
        ],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_dist = custom_distributions.Delta(delta_loc)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    pred_mean, pred_cov = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        delta_dist,
        return_full_cov_flag=True,
        whiten=False,
    )

    expected_mean = np.array(
        [
            [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
             -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
            [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
             0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
        ]
    )
    # Both batch entries share this expected 8x8 covariance.
    shared_cov = np.array(
        [
            [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
             -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
             6.8652224581e-01, -5.9944796542e-01],
            [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
             3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
             -4.0177948193e-01, 5.5950707268e-02],
            [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
             1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
             -6.0924748295e-03, 2.4712302817e+00],
            [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
             1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
             -1.4162076285e-01, 8.5758006162e-01],
            [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
             9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
             8.4467531841e-03, 2.2271565406e+00],
            [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
             -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
             1.3024795901e+00, -3.5594698486e-01],
            [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
             -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
             3.8891409966e+00, 2.9443812570e-03],
            [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
             8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
             2.9443812570e-03, 5.7185561873e+00],
        ]
    )
    expected_var = np.stack([shared_cov, shared_cov], axis=0)

    np.testing.assert_array_almost_equal(pred_mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(pred_cov.numpy(), expected_var)