示例#1
0
    def test_wishart_moments(self):
        """Check Monte Carlo Wishart moments against the closed forms in ``ef``.

        Draws ``num_draws`` samples W from ``sp.stats.wishart(df, v)`` and
        compares sample means against:

        * ``ef.e_log_det_wishart(df, v)`` -- mean of log det W;
        * ``ef.e_log_inv_wishart_diag(df, v)`` -- mean of log diag(W^{-1});
        * ``ef.expected_ljk_prior(lkj_param, df, v)`` -- mean of
          (lkj_param - 1) * log det R, where R is the correlation matrix
          obtained by rescaling W^{-1} to unit diagonal.

        Each comparison uses an absolute tolerance of three Monte Carlo
        standard errors. The sampler is seeded so the outcome is
        deterministic.
        """
        num_draws = 10000
        df = 4.3
        v = np.diag(np.array([2., 3.])) + np.full((2, 2), 0.1)
        wishart_dist = sp.stats.wishart(df=df, scale=v)
        # Fixed seed: this test makes several independent 3-sigma assertions,
        # so an unseeded run would fail by chance on the order of 1% of the
        # time even when the closed forms are correct.
        wishart_draws = wishart_dist.rvs(num_draws, random_state=20170101)
        log_det_draws = np.linalg.slogdet(wishart_draws)[1]
        moment_tolerance = 3.0 * np.std(log_det_draws) / np.sqrt(num_draws)
        print('Wishart e log det test tolerance: ', moment_tolerance)
        np_test.assert_allclose(
            np.mean(log_det_draws), ef.e_log_det_wishart(df, v),
            atol=moment_tolerance)

        # Test the log inverse diagonals. np.linalg.inv and np.diagonal both
        # operate over the leading (draw) axis, so no per-draw loop is needed.
        wishart_inv_draws = np.linalg.inv(wishart_draws)
        inv_diags = np.diagonal(wishart_inv_draws, axis1=1, axis2=2)
        wishart_log_diag = np.log(inv_diags)
        diag_mean = np.mean(wishart_log_diag, axis=0)
        diag_sd = np.std(wishart_log_diag, axis=0)
        moment_tolerance = 3.0 * np.max(diag_sd) / np.sqrt(num_draws)
        print('Wishart e log diag test tolerance: ', moment_tolerance)
        np_test.assert_allclose(
            diag_mean, ef.e_log_inv_wishart_diag(df, v),
            atol=moment_tolerance)

        # Test the LKJ prior. R = D^{-1/2} W^{-1} D^{-1/2} with D = diag(W^{-1}),
        # i.e. R_ij = (W^{-1})_ij / sqrt(D_i * D_j), computed by broadcasting
        # the per-row and per-column scale factors over every draw at once.
        lkj_param = 5.5
        inv_sd = 1. / np.sqrt(inv_diags)
        r_draws = wishart_inv_draws * inv_sd[:, :, None] * inv_sd[:, None, :]
        wishart_log_det_r_draws = \
            np.linalg.slogdet(r_draws)[1] * (lkj_param - 1)

        moment_tolerance = \
            3.0 * np.std(wishart_log_det_r_draws) / np.sqrt(num_draws)
        print('Wishart lkj prior test tolerance: ', moment_tolerance)
        # NOTE: 'ljk' (sic) is the helper's actual name in ``ef``.
        np_test.assert_allclose(
            np.mean(wishart_log_det_r_draws),
            ef.expected_ljk_prior(lkj_param, df, v),
            atol=moment_tolerance)
 def e_log_lkj_inv_prior(self, lkj_param):
     """Return ``ef.expected_ljk_prior`` at this object's 'df' and 'v' values.

     Reads ``self['df'].get()`` and then ``self['v'].get()``, and passes them,
     together with ``lkj_param``, to ``ef.expected_ljk_prior``.

     NOTE(review): presumably this is the expected LKJ log prior (up to a
     constant) on the correlation matrix of an inverse-Wishart(df, v) draw.
     Confirm against ``ef``. The 'ljk' spelling is the helper's real name.
     """
     wishart_df = self['df'].get()
     wishart_scale = self['v'].get()
     return ef.expected_ljk_prior(lkj_param, wishart_df, wishart_scale)