def test_support_shape(dist):
    """Check that a distribution's support is shape-consistent with it.

    For each test-data fixture of ``dist``, this does the following:

    * builds the Pyro distribution from the fixture's parameters;
    * asserts the support's ``event_dim`` equals the distribution's
      ``event_dim``;
    * asserts that ``support.check(x)`` has the shape obtained by
      broadcasting ``batch_shape`` with the leading (non-event) dims of
      ``x``;
    * asserts that every entry of that check is true.
    """
    for fixture_idx in range(dist.get_num_test_data()):
        pyro_d = dist.pyro_dist(**dist.get_dist_params(fixture_idx))
        assert pyro_d.support.event_dim == pyro_d.event_dim
        sample = dist.get_test_data(fixture_idx)
        # Drop the trailing event dims: only the leading dims of the data take
        # part in broadcasting against the distribution's batch_shape.
        batch_part = sample.shape[:sample.dim() - pyro_d.event_dim]
        in_support = pyro_d.support.check(sample)
        assert in_support.shape == broadcast_shape(pyro_d.batch_shape, batch_part)
        assert in_support.all()
def test_batch_log_prob_shape(dist):
    """Check the shape of ``log_prob`` on every test-data fixture.

    For each fixture, build the Pyro distribution and compare
    ``log_prob(x).size()`` against the shape computed by ``_log_prob_shape``.
    The comparison runs inside ``xfail_if_not_implemented()``, presumably so
    that distributions lacking ``log_prob`` are reported as xfail rather than
    errors (verify against that helper's definition).
    """
    for fixture_idx in range(dist.get_num_test_data()):
        pyro_d = dist.pyro_dist(**dist.get_dist_params(fixture_idx))
        sample = dist.get_test_data(fixture_idx)
        with xfail_if_not_implemented():
            # Expected log_prob shape once the data is broadcast against the
            # distribution.
            expected = _log_prob_shape(pyro_d, sample.size())
            assert pyro_d.log_prob(sample).size() == expected
def test_batch_log_prob(dist):
    """Compare the summed batched Pyro ``log_prob`` against scipy.

    The test is skipped when the fixture has no scipy equivalent
    (``scipy_arg_fn is None``). Otherwise, for every batched fixture index
    it does the following:

    * builds the Pyro distribution and evaluates ``log_prob`` on that
      fixture's test data, summed to a scalar;
    * compares that scalar with the sum of
      ``dist.get_scipy_batch_logpdf(idx)`` for the same index. That helper
      presumably returns per-element scipy log-densities; confirm in the
      fixture class.
    """
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        test_data = dist.get_test_data(idx)
        log_prob_sum_pyro = d.log_prob(test_data).sum().item()
        # The scipy reference must be taken from the same fixture index as the
        # Pyro side. A hard-coded -1 would compare every batched fixture
        # against the last fixture's data and parameters.
        log_prob_sum_np = np.sum(dist.get_scipy_batch_logpdf(idx))
        assert_equal(log_prob_sum_pyro, log_prob_sum_np)