def test_delta_log_prob():
    """Check Delta.log_prob at and away from the Delta's own values.

    An exact copy of the constructor tensor must score 0 for every entry.
    Changing one entry must give -inf for that entry and leave the others at 0.
    """
    support = torch.tensor([[3.67], [8.91], [-76.213]])
    delta = custom_distributions.Delta(support)

    # Identical values: every log-prob entry is expected to be exactly 0.
    identical = support.clone()
    log_p = delta.log_prob(identical)
    assert np.all(log_p.detach().numpy() == 0.)

    # Perturb only index 1; only that entry should drop to -inf.
    shifted = support.clone()
    shifted[1] = 8.45
    log_p = delta.log_prob(shifted)
    expected = np.array([0., -np.inf, 0.])
    assert np.all(log_p.detach().numpy() == expected)
# --- Example #2 ---
def test_conditional_delta_diag(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with torch.diag(cov_aa), unwhitened.

    With ``return_full_cov_flag=False`` the returned ``var`` is compared
    against a (2, 8) array, i.e. one row of 8 values per row of the input.
    """
    observed = custom_distributions.Delta(
        torch.tensor(
            [[1.23, -0.23, 7.45, 2.45, -3.45],
             [0.85, 6.4, -8.21, -0.45, -0.7]],
            dtype=utils.TORCH_FLOAT_TYPE,
        ))

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # NOTE(review): torch.diag returns the diagonal of a 2-D tensor (or builds
    # a diagonal matrix from a 1-D one); presumably the fixture's cov_aa is a
    # square matrix and only its diagonal is passed on here — confirm.
    diag_aa = torch.diag(cov_aa)
    mean, var = conditionals.conditional_gaussian(
        diag_aa,
        cov_ba,
        cov_bb,
        observed,
        return_full_cov_flag=False,
        whiten=False,
    )

    expected_mean = np.array([
        [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
         -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
        [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
         0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
    ])
    # The expected variances are identical for both input rows, so one row
    # is written out once and stacked.
    var_row = np.array([
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ])
    expected_var = np.stack([var_row, var_row])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
# --- Example #3 ---
def test_conditional_delta_diag_whiten(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with torch.diag(cov_aa), whitened.

    With ``return_full_cov_flag=False`` the returned ``var`` is compared
    against a (2, 8) array, i.e. one row of 8 values per row of the input.
    """
    observed = custom_distributions.Delta(
        torch.tensor(
            [[1.23, -0.23, 7.45, 2.45, -3.45],
             [0.85, 6.4, -8.21, -0.45, -0.7]],
            dtype=utils.TORCH_FLOAT_TYPE,
        ))

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    # NOTE(review): torch.diag returns the diagonal of a 2-D tensor (or builds
    # a diagonal matrix from a 1-D one); presumably the fixture's cov_aa is a
    # square matrix and only its diagonal is passed on here — confirm.
    diag_aa = torch.diag(cov_aa)
    mean, var = conditionals.conditional_gaussian(
        diag_aa,
        cov_ba,
        cov_bb,
        observed,
        return_full_cov_flag=False,
        whiten=True,
    )

    expected_mean = np.array([
        [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
         -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
        [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
         1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
    ])
    # The expected variances are identical for both input rows, so one row
    # is written out once and stacked.
    var_row = np.array([
        4.4402744540, 3.1750354618, 6.2603802337, 1.7130353089,
        7.1050700747, 3.4529913441, 3.8891409966, 5.7185561873,
    ])
    expected_var = np.stack([var_row, var_row])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
# --- Example #4 ---
def test_conditional_delta_full_cov_whiten(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with the full cov_aa, whitened.

    With ``return_full_cov_flag=True`` the returned ``var`` is compared
    against a (2, 8, 8) array, i.e. one 8x8 matrix per row of the input.
    """
    observed = custom_distributions.Delta(
        torch.tensor(
            [[1.23, -0.23, 7.45, 2.45, -3.45],
             [0.85, 6.4, -8.21, -0.45, -0.7]],
            dtype=utils.TORCH_FLOAT_TYPE,
        ))

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    mean, var = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        observed,
        return_full_cov_flag=True,
        whiten=True,
    )

    expected_mean = np.array([
        [6.6308755724, 7.3815405044, -2.0869767152, -1.6895302877,
         -0.4139590835, 9.1604886668, 10.8441638657, -2.9321998024],
        [0.2575228726, 0.6755548289, 7.6874194130, 15.9547405928,
         1.1266581340, -0.2008611798, -3.7909059985, 10.0338506274],
    ])
    # The expected 8x8 matrix is identical for both input rows, so it is
    # written out once and stacked along a new leading axis.
    cov_block = np.array([
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ])
    expected_var = np.stack([cov_block, cov_block])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)
# --- Example #5 ---
def test_conditional_delta_full_cov(precomputed_cov_mats):
    """Check conditional_gaussian on a Delta input with the full cov_aa, unwhitened.

    With ``return_full_cov_flag=True`` the returned ``var`` is compared
    against a (2, 8, 8) array, i.e. one 8x8 matrix per row of the input.
    """
    observed = custom_distributions.Delta(
        torch.tensor(
            [[1.23, -0.23, 7.45, 2.45, -3.45],
             [0.85, 6.4, -8.21, -0.45, -0.7]],
            dtype=utils.TORCH_FLOAT_TYPE,
        ))

    cov_aa, cov_ba, cov_bb = precomputed_cov_mats
    mean, var = conditionals.conditional_gaussian(
        cov_aa,
        cov_ba,
        cov_bb,
        observed,
        return_full_cov_flag=True,
        whiten=False,
    )

    expected_mean = np.array([
        [1.6243021303, 3.9762279529, -1.0018448174, -0.8217521943,
         -0.2007619034, 3.9222665183, 5.0673831602, -1.4141233969],
        [-1.6721866760, 0.6966202877, 3.5202238201, 6.3271136175,
         0.5595452083, -0.9492091724, -2.9775444011, 4.6694232859],
    ])
    # The expected 8x8 matrix is identical for both input rows, so it is
    # written out once and stacked along a new leading axis.
    cov_block = np.array([
        [4.4402744540e+00, -8.8794111845e-01, -5.2494440661e-01,
         -6.7909241880e-02, -6.9165843075e-02, 7.7879338152e-01,
         6.8652224581e-01, -5.9944796542e-01],
        [-8.8794111845e-01, 3.1750354618e+00, 4.0894705634e-02,
         3.9688953050e-02, 4.1088182004e-03, -4.3666741391e-01,
         -4.0177948193e-01, 5.5950707268e-02],
        [-5.2494440661e-01, 4.0894705634e-02, 6.2603802337e+00,
         1.0111468375e+00, 7.3081832061e-01, -1.7098283477e-01,
         -6.0924748295e-03, 2.4712302817e+00],
        [-6.7909241880e-02, 3.9688953050e-02, 1.0111468375e+00,
         1.7130353089e+00, 9.0428140734e-02, -2.2654425360e-01,
         -1.4162076285e-01, 8.5758006162e-01],
        [-6.9165843075e-02, 4.1088182004e-03, 7.3081832061e-01,
         9.0428140734e-02, 7.1050700747e+00, -3.6494969401e-02,
         8.4467531841e-03, 2.2271565406e+00],
        [7.7879338152e-01, -4.3666741391e-01, -1.7098283477e-01,
         -2.2654425360e-01, -3.6494969401e-02, 3.4529913441e+00,
         1.3024795901e+00, -3.5594698486e-01],
        [6.8652224581e-01, -4.0177948193e-01, -6.0924748295e-03,
         -1.4162076285e-01, 8.4467531841e-03, 1.3024795901e+00,
         3.8891409966e+00, 2.9443812570e-03],
        [-5.9944796542e-01, 5.5950707268e-02, 2.4712302817e+00,
         8.5758006162e-01, 2.2271565406e+00, -3.5594698486e-01,
         2.9443812570e-03, 5.7185561873e+00],
    ])
    expected_var = np.stack([cov_block, cov_block])

    np.testing.assert_array_almost_equal(mean.numpy(), expected_mean)
    np.testing.assert_array_almost_equal(var.numpy(), expected_var)