def test_batch_log_prob(dist):
    """Check that summed batched log-probs match the scipy reference per fixture.

    For every batch-data index exposed by the ``dist`` fixture, build the Pyro
    distribution from that index's parameters, evaluate ``log_prob`` on that
    index's test data, and compare the sum against the sum of the scipy batch
    log-pdf for the *same* index.

    The test is skipped when the fixture has no scipy equivalent
    (``dist.scipy_arg_fn is None``).
    """
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        test_data = dist.get_test_data(idx)
        log_prob_sum_pyro = d.log_prob(test_data).sum().item()
        # Use the reference for the same fixture index as the Pyro side;
        # a fixed index (e.g. -1) would compare every batch against one
        # unrelated fixture's scipy log-pdf.
        log_prob_sum_np = np.sum(dist.get_scipy_batch_logpdf(idx))
        assert_equal(log_prob_sum_pyro, log_prob_sum_np)