def test_wishart_moments(self):
    """Check the Wishart expectation helpers in ``ef`` against Monte Carlo.

    Draws samples W from a 2x2 Wishart distribution (scipy) and compares
    sample means against:

    * ``ef.e_log_det_wishart(df, v)`` -- mean of log det W;
    * ``ef.e_log_inv_wishart_diag(df, v)`` -- mean of log diag(W^{-1});
    * ``ef.expected_ljk_prior(lkj_param, df, v)`` -- mean of
      (lkj_param - 1) * log det R, where R is W^{-1} rescaled to unit
      diagonal (its correlation matrix).

    Each comparison uses an absolute tolerance of three Monte Carlo
    standard errors.
    """
    num_draws = 10000
    df = 4.3
    v = np.diag(np.array([2., 3.])) + np.full((2, 2), 0.1)
    wishart_dist = sp.stats.wishart(df=df, scale=v)
    # Fixed seed: with a 3-standard-error tolerance on several checks, an
    # unseeded run fails at random roughly 1% of the time.
    wishart_draws = wishart_dist.rvs(size=num_draws, random_state=42)

    log_det_draws = np.linalg.slogdet(wishart_draws)[1]
    moment_tolerance = 3.0 * np.std(log_det_draws) / np.sqrt(num_draws)
    print('Wishart e log det test tolerance: ', moment_tolerance)
    np_test.assert_allclose(
        np.mean(log_det_draws), ef.e_log_det_wishart(df, v),
        atol=moment_tolerance)

    # Test the log inverse diagonals.
    # np.linalg.inv broadcasts over the leading (draw) axis.
    wishart_inv_draws = np.linalg.inv(wishart_draws)
    wishart_log_diag = np.log(
        np.diagonal(wishart_inv_draws, axis1=-2, axis2=-1))
    diag_mean = np.mean(wishart_log_diag, axis=0)
    diag_sd = np.std(wishart_log_diag, axis=0)
    moment_tolerance = 3.0 * np.max(diag_sd) / np.sqrt(num_draws)
    print('Wishart e log diag test tolerance: ', moment_tolerance)
    np_test.assert_allclose(
        diag_mean, ef.e_log_inv_wishart_diag(df, v),
        atol=moment_tolerance)

    # Test the LKJ prior
    lkj_param = 5.5

    def get_r_matrix(mat):
        # R = D^{-1/2} M D^{-1/2} with D = diag(M), i.e. rescale M to unit
        # diagonal. Works on one matrix or a stack along leading axes.
        inv_sd = 1. / np.sqrt(np.diagonal(mat, axis1=-2, axis2=-1))
        return mat * inv_sd[..., :, None] * inv_sd[..., None, :]

    wishart_log_det_r_draws = \
        np.linalg.slogdet(get_r_matrix(wishart_inv_draws))[1] * \
        (lkj_param - 1)

    moment_tolerance = \
        3.0 * np.std(wishart_log_det_r_draws) / np.sqrt(num_draws)
    print('Wishart lkj prior test tolerance: ', moment_tolerance)
    np_test.assert_allclose(
        np.mean(wishart_log_det_r_draws),
        ef.expected_ljk_prior(lkj_param, df, v),
        atol=moment_tolerance)
def e_log_det(self):
    """Return ``ef.e_log_det_wishart`` evaluated at this object's parameters.

    Reads the current values of the ``'df'`` and ``'v'`` entries (via
    ``.get()``), in that order, and passes them to
    ``ef.e_log_det_wishart``. Presumably this is the expected log
    determinant of the Wishart variable -- confirm against ``ef``.
    """
    degrees_of_freedom = self['df'].get()
    scale_matrix = self['v'].get()
    return ef.e_log_det_wishart(degrees_of_freedom, scale_matrix)