示例#1
0
def test_expand_error(dist, initial_shape, proposed_shape, default):
    """Check that expanding to an incompatible batch shape raises an error.

    For every parameter set exposed by ``dist``, the distribution is first
    expanded to ``initial_shape + batch_shape`` and then asked to expand to
    ``proposed_shape + batch_shape``; the second call must raise
    ``RuntimeError`` or ``ValueError``.

    Args:
        dist: Test fixture providing ``pyro_dist``, ``get_dist_params`` and
            ``get_num_test_data``.
        initial_shape: Leading dimensions used for the first expansion.
        proposed_shape: Leading dimensions used for the second expansion,
            which is expected to be rejected.
        default: If true, call ``TorchDistribution.expand`` directly instead
            of the distribution's own ``expand`` method.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        # Normalize once so both branches receive the same torch.Size,
        # whatever sequence type the shapes were parametrized with.
        initial_batch_shape = torch.Size(initial_shape) + small.batch_shape
        if default:
            large = TorchDistribution.expand(small, initial_batch_shape)
        else:
            with xfail_if_not_implemented():
                large = small.expand(initial_batch_shape)
        proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
        with pytest.raises((RuntimeError, ValueError)):
            large.expand(proposed_batch_shape)
示例#2
0
def test_expand_new_dim(dist, sample_shape, shape_type, default):
    """Expand each test distribution with new leading dims and check shapes.

    The target batch shape is ``sample_shape + batch_shape``; it is passed to
    ``expand`` after conversion through ``shape_type``. When ``default`` is
    true ``TorchDistribution.expand`` is called directly, otherwise the
    distribution's own ``expand`` is used.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        base = dist.pyro_dist(**dist.get_dist_params(case))
        target_shape = sample_shape + base.batch_shape
        if not default:
            with xfail_if_not_implemented():
                expanded = base.expand(shape_type(target_shape))
        else:
            expanded = TorchDistribution.expand(base, shape_type(target_shape))
        assert expanded.batch_shape == target_shape
        dist_name = dist.get_test_distribution_name()
        if dist_name == 'Stable':
            pytest.skip('Stable does not implement a log_prob method.')
        check_sample_shapes(base, expanded)
示例#3
0
def test_expand_error(dist, initial_shape, proposed_shape, default):
    """Check that expanding to an incompatible batch shape raises an error.

    For every parameter set exposed by ``dist``, the distribution is first
    expanded to ``initial_shape + batch_shape`` and then asked to expand to
    ``proposed_shape + batch_shape``; the second call must raise
    ``RuntimeError`` or ``ValueError``. The test is skipped for
    ``LKJCorrCholesky`` before the failing expansion is attempted.

    Args:
        dist: Test fixture providing ``pyro_dist``, ``get_dist_params``,
            ``get_num_test_data`` and ``get_test_distribution_name``.
        initial_shape: Leading dimensions used for the first expansion.
        proposed_shape: Leading dimensions used for the second expansion,
            which is expected to be rejected.
        default: If true, call ``TorchDistribution.expand`` directly instead
            of the distribution's own ``expand`` method.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        # Normalize once so both branches receive the same torch.Size,
        # whatever sequence type the shapes were parametrized with.
        initial_batch_shape = torch.Size(initial_shape) + small.batch_shape
        if default:
            large = TorchDistribution.expand(small, initial_batch_shape)
        else:
            with xfail_if_not_implemented():
                large = small.expand(initial_batch_shape)
        proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
        if dist.get_test_distribution_name() == 'LKJCorrCholesky':
            # The original message concatenated two literals without a
            # separating space ("notbroadcastable"); keep it one readable
            # sentence.
            pytest.skip('LKJCorrCholesky can expand to a shape not '
                        'broadcastable with its original batch_shape.')
        with pytest.raises((RuntimeError, ValueError)):
            large.expand(proposed_batch_shape)
示例#4
0
def test_subsequent_expands_ok(dist, sample_shapes, default):
    """Chain several expansions and verify the batch shape after each one.

    Each entry of ``sample_shapes`` is prepended to the distribution's
    original batch shape; the result of one expansion becomes the input of
    the next. When ``default`` is true ``TorchDistribution.expand`` is called
    directly, otherwise the distribution's own ``expand`` is used.
    """
    for case in range(dist.get_num_test_data()):
        current = dist.pyro_dist(**dist.get_dist_params(case))
        base_batch_shape = current.batch_shape
        for leading in sample_shapes:
            target = torch.Size(leading) + base_batch_shape
            if not default:
                with xfail_if_not_implemented():
                    expanded = current.expand(target)
            else:
                expanded = TorchDistribution.expand(current, target)
            assert expanded.batch_shape == target
            with xfail_if_not_implemented():
                check_sample_shapes(current, expanded)
            # The next expansion starts from the one just produced.
            current = expanded
示例#5
0
def test_expand_existing_dim(dist, shape_type, default):
    """Grow each size-1 batch dimension to 5 and check the resulting shape.

    Batch dimensions whose size is not 1 are left alone. The target shape is
    passed to ``expand`` after conversion through ``shape_type``. When
    ``default`` is true ``TorchDistribution.expand`` is called directly,
    otherwise the distribution's own ``expand`` is used.
    """
    for case in range(dist.get_num_test_data()):
        base = dist.pyro_dist(**dist.get_dist_params(case))
        for axis, extent in enumerate(base.batch_shape):
            if extent != 1:
                continue
            # Same shape as the base, with only this axis widened to 5.
            target = torch.Size(
                [5 if pos == axis else length
                 for pos, length in enumerate(base.batch_shape)]
            )
            if not default:
                with xfail_if_not_implemented():
                    expanded = base.expand(shape_type(target))
            else:
                expanded = TorchDistribution.expand(base, shape_type(target))
            assert expanded.batch_shape == target
            dist_name = dist.get_test_distribution_name()
            if dist_name == 'Stable':
                pytest.skip('Stable does not implement a log_prob method.')
            check_sample_shapes(base, expanded)
示例#6
0
def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default):
    """Expand a OneHotCategorical whose batch dims were moved into the event.

    A ``OneHotCategorical`` with ``probs`` of shape ``(1, 6)`` is expanded so
    that batch plus event shape is ``(4, 1, 1, 1, 6)``, then
    ``extra_event_dims`` batch dimensions are reinterpreted as event
    dimensions with ``to_event``. The test checks the resulting shapes, a
    successful expand to ``expand_shape``, and two expansions that must raise.
    """
    base = dist.OneHotCategorical(torch.ones(1, 6) / 6)
    if not default:
        expanded_base = base.expand_by([4, 1, 1])
    else:
        expanded_base = TorchDistribution.expand(base, [4, 1, 1, 1])
    reshaped = expanded_base.to_event(extra_event_dims)

    # Split the full shape at the boundary between batch and event dims.
    full_shape = torch.Size([4, 1, 1, 1, 6])
    split_at = 4 - extra_event_dims
    expected_batch = full_shape[:split_at]
    expected_event = full_shape[split_at:]
    assert reshaped.batch_shape == expected_batch
    assert reshaped.event_shape == expected_event

    widened = reshaped.expand(expand_shape)
    assert widened.batch_shape == torch.Size(expand_shape)
    assert widened.event_shape == torch.Size(expected_event)

    # A batch shape that cannot be broadcast must be rejected.
    with pytest.raises((RuntimeError, ValueError)):
        reshaped.expand(expand_shape + [3])

    # Shrinking an already-expanded batch shape must be rejected.
    with pytest.raises((RuntimeError, ValueError)):
        widened.expand(expand_shape[1:])
0
def test_expand_enumerate_support():
    """Check the enumerated support shape of an expanded Categorical.

    A ``Categorical`` built from ``(3, 6)`` probabilities is expanded to
    batch shape ``(4, 3)`` via ``TorchDistribution.expand``; its support
    enumerated with ``expand=True`` must have shape ``(6, 4, 3)``.
    """
    uniform_probs = torch.ones(3, 6) / 6
    categorical = dist.Categorical(uniform_probs)
    expanded = TorchDistribution.expand(categorical, (4, 3))
    support = expanded.enumerate_support(expand=True)
    assert support.shape == (6, 4, 3)