コード例 #1
0
ファイル: test_distributions.py プロジェクト: zyxue/pyro
def test_expand_new_dim(dist, sample_shape, shape_type):
    """Expanding with extra leading dims must yield the combined batch shape.

    For each test-data index of ``dist`` the distribution is constructed and
    expanded to ``sample_shape + batch_shape`` (wrapped by ``shape_type``).
    The expansion, the batch-shape assertion and ``check_sample_shapes`` all
    run inside ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data`` (supplied by the test harness).
        sample_shape: Leading shape that is prepended to the batch shape.
        shape_type: Callable applied to the target shape before ``expand``.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        with xfail_if_not_implemented():
            target = shape_type(sample_shape + base.batch_shape)
            expanded = base.expand(target)
            expected = sample_shape + base.batch_shape
            assert expanded.batch_shape == expected
            check_sample_shapes(base, expanded)
コード例 #2
0
ファイル: test_distributions.py プロジェクト: zyxue/pyro
def test_expand_error(dist, initial_shape, proposed_shape):
    """Re-expanding to ``proposed_shape`` must raise ``RuntimeError``.

    Each distribution is first expanded to ``initial_shape + batch_shape``;
    a second expansion to ``proposed_shape + batch_shape`` is then expected
    to raise ``RuntimeError``. Everything runs inside
    ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        initial_shape: Leading shape used for the first expansion.
        proposed_shape: Leading shape used for the expansion that must fail.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        with xfail_if_not_implemented():
            first_target = torch.Size(initial_shape) + base.batch_shape
            expanded = base.expand(first_target)
            second_target = torch.Size(proposed_shape) + base.batch_shape
            with pytest.raises(RuntimeError):
                expanded.expand(second_target)
コード例 #3
0
def test_expand_by(dist, sample_shape, shape_type):
    """``expand_by`` must prepend ``sample_shape`` to the batch shape.

    After the batch-shape assertion, the test is skipped for the
    distribution named ``'Stable'``; otherwise sample shapes are compared
    with ``check_sample_shapes``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_num_test_data`` and ``get_test_distribution_name``.
        sample_shape: Leading shape passed to ``expand_by``.
        shape_type: Callable applied to ``sample_shape`` before the call.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        expanded = base.expand_by(shape_type(sample_shape))
        expected = sample_shape + base.batch_shape
        assert expanded.batch_shape == expected
        is_stable = dist.get_test_distribution_name() == 'Stable'
        if is_stable:
            pytest.skip('Stable does not implement a log_prob method.')
        check_sample_shapes(base, expanded)
コード例 #4
0
def test_score_errors_non_broadcastable_data_shape(dist):
    """``log_prob`` must reject data whose leading dimension is off by one.

    The data tensor is built from ``d.shape()`` with its first dimension
    increased by one; scoring it is expected to raise ``ValueError`` or
    ``RuntimeError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_batch_data_indices``.
    """
    for case in dist.get_batch_data_indices():
        d = dist.pyro_dist(**dist.get_dist_params(case))
        full_shape = d.shape()
        leading, trailing = full_shape[0], full_shape[1:]
        bad_shape = (leading + 1,) + trailing
        bad_data = torch.ones(bad_shape)
        with pytest.raises((ValueError, RuntimeError)):
            d.log_prob(bad_data)
コード例 #5
0
def test_batch_entropy_shape(dist):
    """The size of ``entropy()`` must equal ``_log_prob_shape(d)``.

    Both the expected-shape computation and the entropy call run inside
    ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
    """
    for case in range(dist.get_num_test_data()):
        d = dist.pyro_dist(**dist.get_dist_params(case))
        with xfail_if_not_implemented():
            # Shape the entropy should have once broadcasting is applied.
            wanted = _log_prob_shape(d)
            entropy = d.entropy()
            assert entropy.size() == wanted
コード例 #6
0
def test_support_shape(dist):
    """Validate the support's event dim and the shape of ``support.check``.

    For each test case the support's ``event_dim`` must equal the
    distribution's, and ``support.check`` on the test data must return an
    all-true result whose shape is the broadcast of the batch shape with the
    data shape stripped of its trailing ``event_dim`` dimensions.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_test_data`` and ``get_num_test_data``.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        d = dist.pyro_dist(**params)
        assert d.support.event_dim == d.event_dim
        data = dist.get_test_data(case)
        in_support = d.support.check(data)
        # Drop the trailing event dimensions from the data shape.
        data_batch_shape = data.shape[:data.dim() - d.event_dim]
        assert in_support.shape == broadcast_shape(d.batch_shape, data_batch_shape)
        assert in_support.all()
コード例 #7
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_score_errors_non_broadcastable_data_shape(dist):
    """Scoring data with a mismatched leading dimension must raise.

    A tensor of ones is created whose shape equals ``d.shape()`` except that
    the first dimension is one larger; ``log_prob`` on it is expected to
    raise ``ValueError`` or ``RuntimeError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_batch_data_indices``.
    """
    batch_indices = dist.get_batch_data_indices()
    for batch_idx in batch_indices:
        params = dist.get_dist_params(batch_idx)
        d = dist.pyro_dist(**params)
        dist_shape = d.shape()
        mismatched_shape = (dist_shape[0] + 1,) + dist_shape[1:]
        mismatched_data = torch.ones(mismatched_shape)
        with pytest.raises((ValueError, RuntimeError)):
            d.log_prob(mismatched_data)
コード例 #8
0
def test_batch_log_prob(dist):
    """Compare the summed batch ``log_prob`` with the scipy reference value.

    The test is skipped when ``dist.scipy_arg_fn`` is ``None``. For every
    index returned by ``dist.get_batch_data_indices()`` the summed
    ``log_prob`` of the test data is compared (via ``assert_equal``) with the
    summed scipy batch log-pdf for the same index.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``scipy_arg_fn``,
            ``get_dist_params``, ``get_test_data``,
            ``get_batch_data_indices`` and ``get_scipy_batch_logpdf``.
    """
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        test_data = dist.get_test_data(idx)
        log_prob_sum_pyro = d.log_prob(test_data).sum().item()
        # Use the loop index (not a hard-coded -1) so the reference value is
        # computed for the same parameters/data that were just scored.
        log_prob_sum_np = np.sum(dist.get_scipy_batch_logpdf(idx))
        assert_equal(log_prob_sum_pyro, log_prob_sum_np)
コード例 #9
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_batch_log_prob_shape(dist):
    """The size of ``log_prob(x)`` must equal ``_log_prob_shape(d, x.size())``.

    The expected-shape computation and the ``log_prob`` call both run inside
    ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_test_data`` and ``get_num_test_data``.
    """
    for case in range(dist.get_num_test_data()):
        d = dist.pyro_dist(**dist.get_dist_params(case))
        data = dist.get_test_data(case)
        with xfail_if_not_implemented():
            # Shape log_prob should have once broadcasting is applied.
            wanted = _log_prob_shape(d, data.size())
            log_p = d.log_prob(data)
            assert log_p.size() == wanted
コード例 #10
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_batch_log_prob(dist):
    """Check the summed batch ``log_prob`` against the scipy reference.

    Skipped when ``dist.scipy_arg_fn`` is ``None``. For each index yielded by
    ``dist.get_batch_data_indices()``, the sum of ``log_prob`` over the test
    data must equal (per ``assert_equal``) the sum of the scipy batch
    log-pdf obtained for that same index.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``scipy_arg_fn``,
            ``get_dist_params``, ``get_test_data``,
            ``get_batch_data_indices`` and ``get_scipy_batch_logpdf``.
    """
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        test_data = dist.get_test_data(idx)
        log_prob_sum_pyro = d.log_prob(test_data).sum().item()
        # The reference must be looked up with the current index; a literal
        # -1 would compare against a different case whenever idx != -1.
        log_prob_sum_np = np.sum(dist.get_scipy_batch_logpdf(idx))
        assert_equal(log_prob_sum_pyro, log_prob_sum_np)
コード例 #11
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_expand_twice(dist):
    """Expanding ``(2, 1) + batch`` and then ``(2, 3) + batch`` must work.

    The first expansion runs unguarded; only the second is wrapped in
    ``xfail_if_not_implemented``. The final batch shape is asserted and the
    sample shapes of both earlier distributions are checked against the
    final one.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        intermediate_shape = torch.Size((2, 1)) + base.batch_shape
        intermediate = base.expand(intermediate_shape)
        final_shape = torch.Size((2, 3)) + base.batch_shape
        with xfail_if_not_implemented():
            final = intermediate.expand(final_shape)
        assert final.batch_shape == final_shape
        for earlier in (base, intermediate):
            check_sample_shapes(earlier, final)
コード例 #12
0
ファイル: test_distributions.py プロジェクト: zippeurfou/pyro
def test_expand_twice(dist):
    """Two successive expansions must produce the final requested batch shape.

    Each distribution is expanded to ``(2, 1) + batch_shape`` and then, under
    ``xfail_if_not_implemented``, to ``(2, 3) + batch_shape``. The resulting
    batch shape is asserted and ``check_sample_shapes`` is run for the
    original and the intermediate distribution against the final one.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        original = dist.pyro_dist(**dist.get_dist_params(case))
        once = original.expand(torch.Size((2, 1)) + original.batch_shape)
        target_shape = torch.Size((2, 3)) + original.batch_shape
        with xfail_if_not_implemented():
            twice = once.expand(target_shape)
        assert twice.batch_shape == target_shape
        check_sample_shapes(original, twice)
        check_sample_shapes(once, twice)
コード例 #13
0
def test_batch_log_prob_shape(dist):
    """Check that ``log_prob`` returns the shape given by ``_log_prob_shape``.

    For each test case, the expected shape is computed from the distribution
    and the data size, and compared with the size of ``d.log_prob(x)``; both
    steps run inside ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_test_data`` and ``get_num_test_data``.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        d = dist.pyro_dist(**params)
        sample = dist.get_test_data(case)
        with xfail_if_not_implemented():
            # Expected log_prob shape after broadcasting.
            shape_wanted = _log_prob_shape(d, sample.size())
            scored = d.log_prob(sample)
            assert scored.size() == shape_wanted
コード例 #14
0
def test_expand_error(dist, initial_shape, proposed_shape):
    """Re-expanding to an incompatible shape must raise ``RuntimeError``.

    Each distribution is expanded to ``initial_shape + batch_shape``; a
    further expansion to ``proposed_shape + batch_shape`` is then expected to
    raise ``RuntimeError``. The test is skipped for the distribution named
    ``'LKJCorrCholesky'``. Everything runs inside
    ``xfail_if_not_implemented``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_num_test_data`` and ``get_test_distribution_name``.
        initial_shape: Leading shape used for the first expansion.
        proposed_shape: Leading shape used for the expansion that must fail.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        with xfail_if_not_implemented():
            large = small.expand(torch.Size(initial_shape) + small.batch_shape)
            proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
            if dist.get_test_distribution_name() == 'LKJCorrCholesky':
                # Trailing space after "not" keeps the two literal pieces
                # from running together as "notbroadcastable".
                pytest.skip('LKJCorrCholesky can expand to a shape not '
                            'broadcastable with its original batch_shape.')
            with pytest.raises(RuntimeError):
                large.expand(proposed_batch_shape)
コード例 #15
0
def test_infer_shapes(dist):
    """``infer_shapes`` must agree with the shapes of a constructed instance.

    Distributions whose class name contains ``"LKJ"`` are marked xfail. For
    the rest, each parameter is mapped to its ``.shape`` (or ``()`` when it
    is not a ``torch.Tensor``), passed to ``infer_shapes``, and the returned
    batch and event shapes are compared with those of the real instance.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
    """
    if "LKJ" in dist.pyro_dist.__name__:
        pytest.xfail(reason="cannot statically compute shape")
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        shapes = {}
        for name, value in params.items():
            # Non-tensor parameters are treated as scalars.
            shapes[name] = value.shape if isinstance(value, torch.Tensor) else ()
        inferred = dist.pyro_dist.infer_shapes(**shapes)
        inferred_batch, inferred_event = inferred
        instance = dist.pyro_dist(**params)
        assert instance.batch_shape == inferred_batch
        assert instance.event_shape == inferred_event
コード例 #16
0
def test_expand_error(dist, initial_shape, proposed_shape, default):
    """A second expansion to ``proposed_shape`` must raise an error.

    The first expansion uses ``TorchDistribution.expand`` when ``default`` is
    truthy, otherwise the distribution's own ``expand`` under
    ``xfail_if_not_implemented``. Expanding the result to
    ``proposed_shape + batch_shape`` is expected to raise ``RuntimeError`` or
    ``ValueError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        initial_shape: Leading shape used for the first expansion.
        proposed_shape: Leading shape used for the expansion that must fail.
        default: Selects the ``TorchDistribution.expand`` code path.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        if not default:
            with xfail_if_not_implemented():
                expanded = base.expand(torch.Size(initial_shape) + base.batch_shape)
        else:
            expanded = TorchDistribution.expand(base, initial_shape + base.batch_shape)
        rejected_shape = torch.Size(proposed_shape) + base.batch_shape
        with pytest.raises((RuntimeError, ValueError)):
            expanded.expand(rejected_shape)
コード例 #17
0
ファイル: test_distributions.py プロジェクト: pyro-ppl/pyro
def test_score_errors_non_broadcastable_data_shape(dist):
    """``log_prob`` must raise for data with a mismatched leading dimension.

    The test is skipped for the distribution named ``"LKJCholesky"`` (the
    skip message points at a PyTorch issue). Otherwise a tensor of ones is
    built whose first dimension is one larger than ``d.shape()[0]`` and
    scoring it must raise ``ValueError`` or ``RuntimeError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_batch_data_indices`` and ``get_test_distribution_name``.
    """
    for batch_idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(batch_idx)
        d = dist.pyro_dist(**params)
        is_lkj_cholesky = dist.get_test_distribution_name() == "LKJCholesky"
        if is_lkj_cholesky:
            pytest.skip("https://github.com/pytorch/pytorch/issues/52724")
        dist_shape = d.shape()
        first, rest = dist_shape[0], dist_shape[1:]
        bad_data = torch.ones((first + 1,) + rest)
        with pytest.raises((ValueError, RuntimeError)):
            d.log_prob(bad_data)
コード例 #18
0
ファイル: test_distributions.py プロジェクト: zyxue/pyro
def test_subsequent_expands_ok(dist, sample_shapes):
    """A chain of expansions must each produce the requested batch shape.

    Starting from the freshly built distribution, every entry of
    ``sample_shapes`` is prepended to the *original* batch shape and the
    current distribution is expanded to it (under
    ``xfail_if_not_implemented``). The expanded distribution becomes the
    starting point for the next entry.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        sample_shapes: Iterable of leading shapes applied one after another.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        current = dist.pyro_dist(**params)
        base_batch_shape = current.batch_shape
        for leading in sample_shapes:
            target = torch.Size(leading) + base_batch_shape
            with xfail_if_not_implemented():
                expanded = current.expand(target)
            assert expanded.batch_shape == target
            check_sample_shapes(current, expanded)
            # Next iteration expands the already-expanded distribution.
            current = expanded
コード例 #19
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_score_errors_event_dim_mismatch(dist):
    """``log_prob`` must raise when the data has one extra trailing dimension.

    Only distributions with a non-empty ``event_shape`` are checked. The test
    is skipped for ``'MultivariateNormal'`` and
    ``'LowRankMultivariateNormal'``; otherwise scoring a tensor of shape
    ``d.shape() + (1,)`` must raise ``ValueError`` or ``RuntimeError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_batch_data_indices`` and ``get_test_distribution_name``.
    """
    for batch_idx in dist.get_batch_data_indices():
        d = dist.pyro_dist(**dist.get_dist_params(batch_idx))
        extra_dim_data = torch.ones(d.shape() + (1,))
        if len(d.event_shape) == 0:
            # Nothing to validate for distributions without event dims.
            continue
        if dist.get_test_distribution_name() == 'MultivariateNormal':
            pytest.skip('MultivariateNormal does not do shape validation in log_prob.')
        elif dist.get_test_distribution_name() == 'LowRankMultivariateNormal':
            pytest.skip('LowRankMultivariateNormal does not do shape validation in log_prob.')
        with pytest.raises((ValueError, RuntimeError)):
            d.log_prob(extra_dim_data)
コード例 #20
0
def test_score_errors_event_dim_mismatch(dist):
    """Scoring data with an extra trailing dimension must raise an error.

    A tensor of ones with shape ``d.shape() + (1,)`` is prepared for every
    batch index. When the distribution has a non-empty ``event_shape`` the
    test is skipped for ``'MultivariateNormal'`` and
    ``'LowRankMultivariateNormal'``; for all others ``log_prob`` must raise
    ``ValueError`` or ``RuntimeError``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_batch_data_indices`` and ``get_test_distribution_name``.
    """
    batch_indices = dist.get_batch_data_indices()
    for batch_idx in batch_indices:
        params = dist.get_dist_params(batch_idx)
        d = dist.pyro_dist(**params)
        wrong_rank_shape = d.shape() + (1,)
        wrong_rank_data = torch.ones(wrong_rank_shape)
        if len(d.event_shape) == 0:
            continue
        # pytest.skip raises, so the second check only runs if the first
        # did not match.
        if dist.get_test_distribution_name() == 'MultivariateNormal':
            pytest.skip('MultivariateNormal does not do shape validation in log_prob.')
        if dist.get_test_distribution_name() == 'LowRankMultivariateNormal':
            pytest.skip('LowRankMultivariateNormal does not do shape validation in log_prob.')
        with pytest.raises((ValueError, RuntimeError)):
            d.log_prob(wrong_rank_data)
コード例 #21
0
def test_expand_new_dim(dist, sample_shape, shape_type, default):
    """Expanding with new leading dims must give ``sample_shape + batch_shape``.

    The expansion goes through ``TorchDistribution.expand`` when ``default``
    is truthy, otherwise through the distribution's own ``expand`` under
    ``xfail_if_not_implemented``. After the batch-shape assertion the test is
    skipped for the distribution named ``'Stable'``; otherwise sample shapes
    are compared with ``check_sample_shapes``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_num_test_data`` and ``get_test_distribution_name``.
        sample_shape: Leading shape prepended to the batch shape.
        shape_type: Callable applied to the target shape before expanding.
        default: Selects the ``TorchDistribution.expand`` code path.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        if not default:
            with xfail_if_not_implemented():
                expanded = base.expand(shape_type(sample_shape + base.batch_shape))
        else:
            expanded = TorchDistribution.expand(base, shape_type(sample_shape + base.batch_shape))
        expected = sample_shape + base.batch_shape
        assert expanded.batch_shape == expected
        is_stable = dist.get_test_distribution_name() == 'Stable'
        if is_stable:
            pytest.skip('Stable does not implement a log_prob method.')
        check_sample_shapes(base, expanded)
コード例 #22
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_expand_existing_dim(dist, shape_type):
    """Each size-1 batch dimension must be expandable to size 5.

    For every batch dimension whose size is 1, a target batch shape is built
    with that dimension set to 5, the distribution is expanded to it (under
    ``xfail_if_not_implemented``), and the resulting batch shape and sample
    shapes are checked.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        shape_type: Callable applied to the target shape before expanding.
    """
    for case in range(dist.get_num_test_data()):
        base = dist.pyro_dist(**dist.get_dist_params(case))
        for dim, size in enumerate(base.batch_shape):
            if size == 1:
                sizes = list(base.batch_shape)
                sizes[dim] = 5
                target = torch.Size(sizes)
                with xfail_if_not_implemented():
                    expanded = base.expand(shape_type(target))
                assert expanded.batch_shape == target
                check_sample_shapes(base, expanded)
コード例 #23
0
ファイル: test_distributions.py プロジェクト: zippeurfou/pyro
def test_expand_existing_dim(dist, shape_type):
    """Expanding a singleton batch dimension to 5 must succeed.

    Batch dimensions whose size is not 1 are ignored. For each singleton
    dimension the batch shape is copied with that entry replaced by 5, the
    distribution is expanded to it inside ``xfail_if_not_implemented``, and
    the new batch shape and sample shapes are verified.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        shape_type: Callable applied to the target shape before expanding.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        for axis, extent in enumerate(base.batch_shape):
            if extent != 1:
                continue
            # Same batch shape, but with the singleton axis widened to 5.
            target = torch.Size(
                [5 if i == axis else s for i, s in enumerate(base.batch_shape)]
            )
            with xfail_if_not_implemented():
                widened = base.expand(shape_type(target))
            assert widened.batch_shape == target
            check_sample_shapes(base, widened)
コード例 #24
0
def test_enumerate_support_shape(dist):
    """Check the shapes returned by ``enumerate_support`` in both modes.

    The test is skipped when the distribution class has a falsy
    ``has_enumerate_support``. Otherwise, inside ``xfail_if_not_implemented``:

    * the default call must have shape ``(n,) + batch_shape + event_shape``;
    * ``expand=True`` must equal the default result;
    * ``expand=False`` must have singleton batch dimensions and compare
      equal elementwise to the expanded result.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
    """
    if not dist.pyro_dist.has_enumerate_support:
        pytest.skip()
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        d = dist.pyro_dist(**params)
        with xfail_if_not_implemented():
            default_support = d.enumerate_support()
            count = default_support.shape[0]
            assert default_support.shape == (count,) + d.batch_shape + d.event_shape

            expanded = d.enumerate_support(expand=True)
            assert_equal(default_support, expanded)

            collapsed = d.enumerate_support(expand=False)
            singleton_batch = (1,) * len(d.batch_shape)
            assert collapsed.shape == (count,) + singleton_batch + d.event_shape
            assert (expanded == collapsed).all()
コード例 #25
0
def test_expand_existing_dim(dist, shape_type, default):
    """Each size-1 batch dimension must be expandable to size 5.

    For every singleton batch dimension a target shape is built with that
    dimension set to 5. The expansion uses ``TorchDistribution.expand`` when
    ``default`` is truthy, otherwise the distribution's own ``expand`` under
    ``xfail_if_not_implemented``. After the batch-shape assertion the test is
    skipped for the distribution named ``'Stable'``; otherwise sample shapes
    are compared with ``check_sample_shapes``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params``,
            ``get_num_test_data`` and ``get_test_distribution_name``.
        shape_type: Callable applied to the target shape before expanding.
        default: Selects the ``TorchDistribution.expand`` code path.
    """
    for case in range(dist.get_num_test_data()):
        base = dist.pyro_dist(**dist.get_dist_params(case))
        for axis, extent in enumerate(base.batch_shape):
            if extent == 1:
                sizes = list(base.batch_shape)
                sizes[axis] = 5
                target = torch.Size(sizes)
                if not default:
                    with xfail_if_not_implemented():
                        expanded = base.expand(shape_type(target))
                else:
                    expanded = TorchDistribution.expand(base, shape_type(target))
                assert expanded.batch_shape == target
                is_stable = dist.get_test_distribution_name() == 'Stable'
                if is_stable:
                    pytest.skip('Stable does not implement a log_prob method.')
                check_sample_shapes(base, expanded)
コード例 #26
0
ファイル: test_distributions.py プロジェクト: lewisKit/pyro
def test_expand_new_dim(dist, sample_shape, shape_type):
    """Expanding with new leading dims must give ``sample_shape + batch_shape``.

    The expansion is not guarded here, so any exception it raises fails the
    test directly.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        sample_shape: Leading shape prepended to the batch shape.
        shape_type: Callable applied to the target shape before expanding.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        target = shape_type(sample_shape + base.batch_shape)
        expanded = base.expand(target)
        expected = sample_shape + base.batch_shape
        assert expanded.batch_shape == expected
        check_sample_shapes(base, expanded)
コード例 #27
0
ファイル: test_distributions.py プロジェクト: zippeurfou/pyro
def test_expand_by(dist, sample_shape, shape_type):
    """``expand_by`` must prepend ``sample_shape`` to the batch shape.

    For each test case the distribution is expanded by
    ``shape_type(sample_shape)``; the resulting batch shape is asserted and
    sample shapes are compared with ``check_sample_shapes``.

    Args:
        dist: Test fixture exposing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        sample_shape: Leading shape passed to ``expand_by``.
        shape_type: Callable applied to ``sample_shape`` before the call.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        params = dist.get_dist_params(case)
        base = dist.pyro_dist(**params)
        extra_dims = shape_type(sample_shape)
        expanded = base.expand_by(extra_dims)
        expected = sample_shape + base.batch_shape
        assert expanded.batch_shape == expected
        check_sample_shapes(base, expanded)