def test_delta_log_prob():
    """Check Delta.log_prob on matching and on partially mismatching inputs.

    An exact copy of the stored values must give log-probability 0 for every
    entry; after changing the second entry, that entry alone must give -inf.
    """
    support = torch.tensor([[3.67], [8.91], [-76.213]])
    dist = custom_distributions.Delta(support)

    # Query with an untouched copy of the values the distribution holds.
    same_query = support.clone()
    log_prob_same = dist.log_prob(same_query).detach().numpy()
    assert np.all(log_prob_same == 0.)

    # Perturb only index 1 so that exactly one entry no longer matches.
    perturbed_query = support.clone()
    perturbed_query[1] = 8.45
    log_prob_perturbed = dist.log_prob(perturbed_query).detach().numpy()
    expected = np.array([0., -np.inf, 0.])
    assert np.all(log_prob_perturbed == expected)
def test_conditional_delta_diag(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with whiten=False.

    Calls the conditional with ``return_full_cov_flag=False`` and compares the
    returned mean and variance against hard-coded reference values.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_input = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # NOTE(review): presumably cov_aa is a square matrix here, so torch.diag
    # reduces it to its diagonal -- confirm against the fixture.
    cov_aa_diag = torch.diag(cov_aa)

    mean, var = conditionals.conditional_gaussian(
        cov_aa_diag, cov_ba, cov_bb, delta_input,
        return_full_cov_flag=False, whiten=False)

    expected_mean = np.array([
        [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
         -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
        [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
         0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
    ])

    # The reference variance is the same for both rows of the input, so the
    # row is written once and repeated.
    reference_var_row = [
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ]
    expected_var = np.array([reference_var_row, reference_var_row])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
def test_conditional_delta_diag_whiten(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with whiten=True.

    Calls the conditional with ``return_full_cov_flag=False`` and compares the
    returned mean and variance against hard-coded reference values.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_input = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # NOTE(review): presumably cov_aa is a square matrix here, so torch.diag
    # reduces it to its diagonal -- confirm against the fixture.
    cov_aa_diag = torch.diag(cov_aa)

    mean, var = conditionals.conditional_gaussian(
        cov_aa_diag, cov_ba, cov_bb, delta_input,
        return_full_cov_flag=False, whiten=True)

    expected_mean = np.array([
        [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
         -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
        [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
         1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
    ])

    # The reference variance is the same for both rows of the input, so the
    # row is written once and repeated.
    reference_var_row = [
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ]
    expected_var = np.array([reference_var_row, reference_var_row])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
def test_conditional_delta_full_cov_whiten(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input: full covariance, whiten=True.

    Calls the conditional with ``return_full_cov_flag=True`` and compares the
    returned mean and covariance against hard-coded reference values.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_input = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    mean, var = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, delta_input,
        return_full_cov_flag=True, whiten=True)

    expected_mean = np.array([
        [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
         -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
        [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
         1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
    ])

    # The reference covariance is identical for both rows of the input, so
    # the 8x8 matrix is written once and used for each batch entry.
    reference_cov = [
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ]
    expected_var = np.array([reference_cov, reference_cov])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
def test_conditional_delta_full_cov(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input: full covariance, whiten=False.

    Calls the conditional with ``return_full_cov_flag=True`` and compares the
    returned mean and covariance against hard-coded reference values.
    """
    point_values = torch.tensor(
        [[1.23, -0.23, 7.45, 2.45, -3.45],
         [0.85, 6.4, -8.21, -0.45, -0.7]],
        dtype=utils.TORCH_FLOAT_TYPE,
    )
    delta_input = custom_distributions.Delta(point_values)

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    mean, var = conditionals.conditional_gaussian(
        cov_aa, cov_ba, cov_bb, delta_input,
        return_full_cov_flag=True, whiten=False)

    expected_mean = np.array([
        [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
         -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
        [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
         0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
    ])

    # The reference covariance is identical for both rows of the input, so
    # the 8x8 matrix is written once and used for each batch entry.
    reference_cov = [
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ]
    expected_var = np.array([reference_cov, reference_cov])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)