def test_conditional_mvn_diag_whiten(precomputed_cov_mats):
    """Check the whitened MVN conditional when only marginal variances are returned.

    A batch of two 5-dimensional ``MultivariateNormal`` distributions (given
    by ``loc`` and a lower-triangular ``scale_tril``) is conditioned with
    ``whiten=True`` and ``return_full_cov_flag=False``. The resulting mean and
    variance (8 values per batch row) must match the stored reference values
    at ``assert_array_almost_equal``'s default precision.
    """
    loc = torch.tensor(
        [[-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
         [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191]],
        dtype=utils.TORCH_FLOAT_TYPE)
    scale_tril = torch.tensor(
        [[[1.2913755, 0., 0., 0., 0.],
          [-0.21657062, 1.38901088, 0., 0., 0.],
          [1.05706677, -1.69033564, 0.81725719, 0., 0.],
          [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.],
          [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327]],
         [[2.29744629, 0., 0., 0., 0.],
          [-0.37817459, 2.0557189, 0., 0., 0.],
          [-1.14057908, -1.58899131, 2.05868389, 0., 0.],
          [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.],
          [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534]]],
        dtype=utils.TORCH_FLOAT_TYPE)
    conditioning_dist = distributions.MultivariateNormal(
        loc, scale_tril=scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # With return_full_cov_flag=False this test passes only the diagonal of cov_aa.
    cond_mean, cond_var = conditionals.conditional_gaussian(
        cov_aa.diag(), cov_ba, cov_bb, conditioning_dist,
        return_full_cov_flag=False, whiten=True)

    ref_mean = np.array(
        [[0.9983176021, -3.9736127037, 0.6906933359, 2.0855503170,
          0.0715756755, -1.2503840376, -1.8458709690, 0.8542689689],
         [-1.1424803951, -1.3786447703, -0.9203628273, -2.2190894546,
          -0.1230040081, -1.4342874653, -1.0501022539, -1.1778143477]])
    ref_var = np.array(
        [[14.2586345250, 14.0537436795, 9.2947300720, 14.1982420873,
          7.1713542879, 19.1857305446, 18.6136875294, 10.9176187167],
         [14.4353048611, 25.0036833948, 11.4511969440, 27.2123889213,
          7.2121129353, 27.2155613383, 22.9932268895, 14.4766175341]])
    for actual, reference in ((cond_mean, ref_mean), (cond_var, ref_var)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_mvn_diag(precomputed_cov_mats):
    """Check the non-whitened MVN conditional when only marginal variances are returned.

    The conditioning distribution is a batch of two 5-dimensional
    ``MultivariateNormal`` distributions built from ``loc`` and a
    lower-triangular ``scale_tril``. The call uses ``whiten=False`` and
    ``return_full_cov_flag=False``. Mean and variance (8 values per batch row)
    are compared against stored reference values.
    """
    loc = torch.tensor(
        [[-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
         [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191]],
        dtype=utils.TORCH_FLOAT_TYPE)
    scale_tril = torch.tensor(
        [[[1.2913755, 0., 0., 0., 0.],
          [-0.21657062, 1.38901088, 0., 0., 0.],
          [1.05706677, -1.69033564, 0.81725719, 0., 0.],
          [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.],
          [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327]],
         [[2.29744629, 0., 0., 0., 0.],
          [-0.37817459, 2.0557189, 0., 0., 0.],
          [-1.14057908, -1.58899131, 2.05868389, 0., 0.],
          [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.],
          [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534]]],
        dtype=utils.TORCH_FLOAT_TYPE)
    conditioning_dist = distributions.MultivariateNormal(
        loc, scale_tril=scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # With return_full_cov_flag=False this test passes only the diagonal of cov_aa.
    cond_mean, cond_var = conditionals.conditional_gaussian(
        cov_aa.diag(), cov_ba, cov_bb, conditioning_dist,
        return_full_cov_flag=False, whiten=False)

    ref_mean = np.array(
        [[1.0395413577, -1.8514740569, 0.3141941933, 0.8476137584,
          0.0382751075, -0.3854980598, -0.8345681935, 0.3999459519],
         [-0.0090321213, -0.7372578999, -0.4162285925, -0.8602622449,
          -0.0616100437, -0.4430488838, -0.2661613907, -0.5424388696]])
    ref_var = np.array(
        [[5.5969765530, 6.3781133317, 6.9028852306, 3.7262482077,
          7.1214550454, 7.0502994187, 7.7957730662, 6.8558710306],
         [5.6527201408, 12.4449391102, 7.3493964930, 5.7165059115,
          7.1316647669, 10.6323246135, 10.7472708960, 7.6216942496]])
    for actual, reference in ((cond_mean, ref_mean), (cond_var, ref_var)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_delta_diag(precomputed_cov_mats):
    """Check the non-whitened conditional of a ``Delta`` when only marginal variances are returned.

    Two rows of 5 fixed values are wrapped in ``custom_distributions.Delta``.
    The conditional is requested with ``whiten=False`` and
    ``return_full_cov_flag=False``. The reference variances are the same for
    both batch rows, so they are written once below.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE)
    delta = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # With return_full_cov_flag=False this test passes only the diagonal of cov_aa.
    cond_mean, cond_var = conditionals.conditional_gaussian(
        cov_aa.diag(), cov_ba, cov_bb, delta,
        return_full_cov_flag=False, whiten=False)

    ref_mean = np.array(
        [[1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
          -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
         [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
          0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859]])
    shared_var_row = [4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
                      7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873]
    ref_var = np.array([shared_var_row, shared_var_row])
    for actual, reference in ((cond_mean, ref_mean), (cond_var, ref_var)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_delta_diag_whiten(precomputed_cov_mats):
    """Check the whitened conditional of a ``Delta`` when only marginal variances are returned.

    The same fixed values as the non-whitened Delta test are wrapped in
    ``custom_distributions.Delta``. The call uses ``whiten=True`` and
    ``return_full_cov_flag=False``. The reference variances are the same for
    both batch rows, so they are written once below.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE)
    delta = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # With return_full_cov_flag=False this test passes only the diagonal of cov_aa.
    cond_mean, cond_var = conditionals.conditional_gaussian(
        cov_aa.diag(), cov_ba, cov_bb, delta,
        return_full_cov_flag=False, whiten=True)

    ref_mean = np.array(
        [[6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
          -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
         [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
          1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274]])
    shared_var_row = [4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
                      7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873]
    ref_var = np.array([shared_var_row, shared_var_row])
    for actual, reference in ((cond_mean, ref_mean), (cond_var, ref_var)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_mvn_full_cov_whiten(precomputed_cov_mats):
    """Check the whitened MVN conditional when the full output covariance is returned.

    A batch of two 5-dimensional ``MultivariateNormal`` distributions is
    conditioned with ``whiten=True`` and ``return_full_cov_flag=True``. The
    full ``cov_aa`` is passed. The reference mean has shape (2, 8) and the
    reference covariance has shape (2, 8, 8).
    """
    loc = torch.tensor(
        [[-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
         [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191]],
        dtype=utils.TORCH_FLOAT_TYPE)
    scale_tril = torch.tensor(
        [[[1.2913755, 0., 0., 0., 0.],
          [-0.21657062, 1.38901088, 0., 0., 0.],
          [1.05706677, -1.69033564, 0.81725719, 0., 0.],
          [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.],
          [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327]],
         [[2.29744629, 0., 0., 0., 0.],
          [-0.37817459, 2.0557189, 0., 0., 0.],
          [-1.14057908, -1.58899131, 2.05868389, 0., 0.],
          [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.],
          [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534]]],
        dtype=utils.TORCH_FLOAT_TYPE)
    conditioning_dist = distributions.MultivariateNormal(
        loc, scale_tril=scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cond_mean, cond_cov = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, conditioning_dist,
        return_full_cov_flag=True, whiten=True)

    ref_mean = np.array(
        [[0.9983176021, -3.9736127037, 0.6906933359, 2.0855503170,
          0.0715756755, -1.2503840376, -1.8458709690, 0.8542689689],
         [-1.1424803951, -1.3786447703, -0.9203628273, -2.2190894546,
          -0.1230040081, -1.4342874653, -1.0501022539, -1.1778143477]])
    ref_cov = np.array(
        [[[14.2586345250, 9.1822888138, -2.6004046698, -4.0531016056,
           -0.3628278468, 12.8795825333, 11.8793483846, -3.3109476844],
          [9.1822888138, 14.0537436795, -3.0078651759, -5.9977582226,
           -0.4290556306, 12.5841457955, 11.9920397559, -3.9239952650],
          [-2.6004046698, -3.0078651759, 9.2947300720, 7.1303199998,
           1.1779780651, -3.5088886346, -4.5304743531, 6.4427563661],
          [-4.0531016056, -5.9977582226, 7.1303199998, 14.1982420873,
           0.9852830967, -6.9005426267, -9.0417278718, 8.8555650766],
          [-0.3628278468, -0.4290556306, 1.1779780651, 0.9852830967,
           7.1713542879, -0.5043240844, -0.6476378715, 2.8130058040],
          [12.8795825333, 12.5841457955, -3.5088886346, -6.9005426267,
           -0.5043240844, 19.1857305446, 15.9460640932, -4.7058678889],
          [11.8793483846, 11.9920397559, -4.5304743531, -9.0417278718,
           -0.6476378715, 15.9460640932, 18.6136875294, -5.9171923298],
          [-3.3109476844, -3.9239952650, 6.4427563661, 8.8555650766,
           2.8130058040, -4.7058678889, -5.9171923298, 10.9176187167]],
         [[14.4353048611, 7.6440115383, 1.0779014652, 5.6157728594,
           0.0847217439, 14.3465260481, 11.8000939369, 1.3130612969],
          [7.6440115383, 25.0036833948, -1.0122776158, -2.6434071073,
           -0.0927873657, 19.4050847421, 17.3934295022, -1.2793426390],
          [1.0779014652, -1.0122776158, 11.4511969440, 12.2886570287,
           1.4698466940, -0.2455049374, -2.5919774292, 9.2114657874],
          [5.6157728594, -2.6434071073, 12.2886570287, 27.2123889213,
           1.6542042466, 1.3498022098, -4.1299704160, 15.4245118703],
          [0.0847217439, -0.0927873657, 1.4698466940, 1.6542042466,
           7.2121129353, -0.0816546991, -0.3961193855, 3.1899994546],
          [14.3465260481, 19.4050847421, -0.2455049374, 1.3498022098,
           -0.0816546991, 27.2155613383, 21.7753020846, -0.5858162823],
          [11.8000939369, 17.3934295022, -2.5919774292, -4.1299704160,
           -0.3961193855, 21.7753020846, 22.9932268895, -3.4791494930],
          [1.3130612969, -1.2793426390, 9.2114657874, 15.4245118703,
           3.1899994546, -0.5858162823, -3.4791494930, 14.4766175341]]])
    for actual, reference in ((cond_mean, ref_mean), (cond_cov, ref_cov)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_mvn_full_cov(precomputed_cov_mats):
    """Check the non-whitened MVN conditional when the full output covariance is returned.

    A batch of two 5-dimensional ``MultivariateNormal`` distributions is
    conditioned with ``whiten=False`` and ``return_full_cov_flag=True``. The
    full ``cov_aa`` is passed. The reference mean has shape (2, 8) and the
    reference covariance has shape (2, 8, 8).
    """
    loc = torch.tensor(
        [[-2.28405212, 0.92506561, 0.6771707, -0.24862164, 1.43093486],
         [-0.4704259, -0.94033398, 0.39634775, -0.43562778, 0.27752191]],
        dtype=utils.TORCH_FLOAT_TYPE)
    scale_tril = torch.tensor(
        [[[1.2913755, 0., 0., 0., 0.],
          [-0.21657062, 1.38901088, 0., 0., 0.],
          [1.05706677, -1.69033564, 0.81725719, 0., 0.],
          [0.24842882, -1.07745551, -1.31826034, 2.02463485, 0.],
          [1.16116644, -0.09549434, -1.33062235, 0.84813075, 0.32257327]],
         [[2.29744629, 0., 0., 0., 0.],
          [-0.37817459, 2.0557189, 0., 0., 0.],
          [-1.14057908, -1.58899131, 2.05868389, 0., 0.],
          [-1.02594764, -0.03201106, -0.20093787, 2.6618695, 0.],
          [0.7124642, 0.47203433, 0.17535208, -1.18123066, 0.46893534]]],
        dtype=utils.TORCH_FLOAT_TYPE)
    conditioning_dist = distributions.MultivariateNormal(
        loc, scale_tril=scale_tril)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cond_mean, cond_cov = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, conditioning_dist,
        return_full_cov_flag=True, whiten=False)

    ref_mean = np.array(
        [[1.0395413577, -1.8514740569, 0.3141941933, 0.8476137584,
          0.0382751075, -0.3854980598, -0.8345681935, 0.3999459519],
         [-0.0090321213, -0.7372578999, -0.4162285925, -0.8602622449,
          -0.0616100437, -0.4430488838, -0.2661613907, -0.5424388696]])
    ref_cov = np.array(
        [[[5.5969765530, 0.8987337154, -1.0047342826, -0.9199924557,
           -0.1427258369, 2.6960902126, 2.6752037660, -1.2360837610],
          [0.8987337154, 6.3781133317, -0.6323356576, -1.1976265383,
           -0.0958950157, 2.9462344884, 2.9279693862, -0.8347931889],
          [-1.0047342826, -0.6323356576, 6.9028852306, 2.1426581153,
           0.8332082246, -0.9641806581, -1.1681775127, 3.3260099106],
          [-0.9199924557, -1.1976265383, 2.1426581153, 3.7262482077,
           0.2696427010, -1.6835020270, -2.2145982904, 2.3613305103],
          [-0.1427258369, -0.0958950157, 0.8332082246, 0.2696427010,
           7.1214550454, -0.1550713938, -0.1714103286, 2.3634627316],
          [2.6960902126, 2.9462344884, -0.9641806581, -1.6835020270,
           -0.1550713938, 7.0502994187, 4.8934529747, -1.4055364028],
          [2.6752037660, 2.9279693862, -1.1681775127, -2.2145982904,
           -0.1714103286, 4.8934529747, 7.7957730662, -1.5401611640],
          [-1.2360837610, -0.8347931889, 3.3260099106, 2.3613305103,
           2.3634627316, -1.4055364028, -1.5401611640, 6.8558710306]],
         [[5.6527201408, 0.2022671057, -0.6651704627, 0.0611261029,
           -0.1054391073, 2.6737827700, 2.4564956142, -0.8162208570],
          [0.2022671057, 12.4449391102, -0.2266653547, -0.4712621969,
           -0.0265284815, 7.1430618936, 6.7646918858, -0.3103502309],
          [-0.6651704627, -0.2266653547, 7.3493964930, 3.0578181632,
           0.9001054998, -0.5528150962, -1.0520779708, 3.9105112597],
          [0.0611261029, -0.4712621969, 3.0578181632, 5.7165059115,
           0.4021351181, -0.6077735410, -1.7997565173, 3.5505399050],
          [-0.1054391073, -0.0265284815, 0.9001054998, 0.4021351181,
           7.1316647669, -0.1003523807, -0.1585082426, 2.4513695028],
          [2.6737827700, 7.1430618936, -0.5528150962, -0.6077735410,
           -0.1003523807, 10.6323246135, 8.0648547968, -0.8970605421],
          [2.4564956142, 6.7646918858, -1.0520779708, -1.7997565173,
           -0.1585082426, 8.0648547968, 10.7472708960, -1.4129403011],
          [-0.8162208570, -0.3103502309, 3.9105112597, 3.5505399050,
           2.4513695028, -0.8970605421, -1.4129403011, 7.6216942496]]])
    for actual, reference in ((cond_mean, ref_mean), (cond_cov, ref_cov)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_delta_full_cov_whiten(precomputed_cov_mats):
    """Check the whitened conditional of a ``Delta`` when the full output covariance is returned.

    Two rows of 5 fixed values are wrapped in ``custom_distributions.Delta``.
    The call uses ``whiten=True`` and ``return_full_cov_flag=True``. The
    reference 8x8 covariance is the same for both batch members, so it is
    written once and repeated.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE)
    delta = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cond_mean, cond_cov = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, delta,
        return_full_cov_flag=True, whiten=True)

    ref_mean = np.array(
        [[6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
          -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
         [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
          1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274]])
    shared_cov = [
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ]
    ref_cov = np.array([shared_cov, shared_cov])
    for actual, reference in ((cond_mean, ref_mean), (cond_cov, ref_cov)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)
def test_conditional_delta_full_cov(precomputed_cov_mats):
    """Check the non-whitened conditional of a ``Delta`` when the full output covariance is returned.

    Two rows of 5 fixed values are wrapped in ``custom_distributions.Delta``.
    The call uses ``whiten=False`` and ``return_full_cov_flag=True``. The
    reference 8x8 covariance is the same for both batch members, so it is
    written once and repeated.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE)
    delta = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    cond_mean, cond_cov = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, delta,
        return_full_cov_flag=True, whiten=False)

    ref_mean = np.array(
        [[1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
          -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
         [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
          0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859]])
    shared_cov = [
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ]
    ref_cov = np.array([shared_cov, shared_cov])
    for actual, reference in ((cond_mean, ref_mean), (cond_cov, ref_cov)):
        np.testing.assert_array_almost_equal(actual.numpy(), reference)