Пример #1
0
def test_reshape_reshape(sample_dim, extra_event_dims):
    """Check shapes after reshaping an already-reshaped Bernoulli.

    A Bernoulli with batch shape ``(4, 3)`` is expanded by ``(6, 5)`` and its
    rightmost two dims are declared dependent, giving batch ``(6, 5)`` and
    event ``(4, 3)``. That distribution is then expanded by ``sample_shape``
    and ``extra_event_dims`` more dims are moved into the event shape.

    :param int sample_dim: number of sample dims to prepend; expected to be at
        most 2 since ``sample_shape`` is sliced from ``torch.Size((8, 7))``.
    :param int extra_event_dims: number of extra dims passed to
        ``.independent()`` on the expanded distribution.
    """
    batch_dim = 2
    batch_shape, event_shape = torch.Size((6, 5)), torch.Size((4, 3))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape
    # Number of leading dims of ``shape`` that remain batch dims at the end.
    split = sample_dim + batch_dim - extra_event_dims

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(0.5 * torch.ones(event_shape))
    dist1 = dist0.expand_by(batch_shape).independent(2)
    assert dist1.event_shape == event_shape
    assert dist1.batch_shape == batch_shape

    # Check that reshaping has the desired final shape.
    dist = dist1.expand_by(sample_shape).independent(extra_event_dims)
    assert dist.batch_shape == shape[:split]
    assert dist.event_shape == shape[split:]
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape
    assert dist.log_prob(sample).shape == shape[:split]

    # Check enumerate support. The event shape always includes the (4, 3)
    # dims declared by ``dist1``, so it is never scalar here and enumeration
    # must always be rejected; a scalar-event branch would be unreachable.
    with pytest.raises(NotImplementedError):
        dist.enumerate_support()
Пример #2
0
def test_reshape_reshape(sample_dim, extra_event_dims):
    """Check shapes after reshaping an already-reshaped Bernoulli.

    A Bernoulli with batch shape ``(4, 3)`` is expanded by ``(6, 5)`` and its
    rightmost two dims are declared dependent, giving batch ``(6, 5)`` and
    event ``(4, 3)``. That distribution is then expanded by ``sample_shape``
    and ``extra_event_dims`` more dims are moved into the event shape.

    :param int sample_dim: number of sample dims to prepend; expected to be
        at most 2 since ``sample_shape`` is sliced from
        ``torch.Size((8, 7))``.
    :param int extra_event_dims: number of extra dims passed to
        ``.independent()`` on the expanded distribution.
    """
    batch_dim = 2
    batch_shape, event_shape = torch.Size((6, 5)), torch.Size((4, 3))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape
    # Number of leading dims of ``shape`` that remain batch dims at the end.
    split = sample_dim + batch_dim - extra_event_dims

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(0.5 * torch.ones(event_shape))
    dist1 = dist0.expand_by(batch_shape).independent(2)
    assert dist1.event_shape == event_shape
    assert dist1.batch_shape == batch_shape

    # Check that reshaping has the desired final shape.
    dist = dist1.expand_by(sample_shape).independent(extra_event_dims)
    assert dist.batch_shape == shape[:split]
    assert dist.event_shape == shape[split:]
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape
    assert dist.log_prob(sample).shape == shape[:split]

    # Check enumerate support. The event shape always includes the (4, 3)
    # dims declared by ``dist1``, so it is never scalar here and
    # enumeration must always be rejected; a scalar-event branch would be
    # unreachable.
    with pytest.raises(NotImplementedError):
        dist.enumerate_support()
Пример #3
0
def test_reshape(sample_dim, extra_event_dims):
    """Check shapes after expanding a batched Bernoulli and calling ``.to_event()``.

    The base distribution has batch shape ``(5, 4, 3)`` and a scalar event.
    It is expanded by ``sample_shape`` and then ``extra_event_dims`` dims are
    moved into the event shape. Samples, moments and ``log_prob`` must have
    the expected shapes. ``enumerate_support`` must work only while the
    event is still scalar.
    """
    batch_dim = 3
    batch_shape = torch.Size((5, 4, 3))
    event_shape = torch.Size()
    sample_shape = torch.Size((8, 7, 6))[batch_dim - sample_dim:]
    full_shape = sample_shape + batch_shape + event_shape
    remaining_batch = sample_dim + batch_dim - extra_event_dims

    # Starting distribution: batch (5, 4, 3), scalar event.
    base = Bernoulli(0.5 * torch.ones(batch_shape))
    assert base.event_shape == event_shape
    assert base.batch_shape == batch_shape

    # Reshape, then verify the shape of everything derived from it.
    reshaped = base.expand_by(sample_shape).to_event(extra_event_dims)
    draw = reshaped.sample()
    assert draw.shape == full_shape
    assert reshaped.mean.shape == full_shape
    assert reshaped.variance.shape == full_shape
    assert reshaped.log_prob(draw).shape == full_shape[:remaining_batch]

    # Enumeration is only supported for scalar events.
    if not reshaped.event_shape:
        enumerated = (2, ) + full_shape
        assert reshaped.enumerate_support().shape == enumerated
        assert reshaped.enumerate_support(expand=True).shape == enumerated
        singleton_dims = (1, ) * len(sample_shape + batch_shape)
        assert reshaped.enumerate_support(expand=False).shape == (
            (2, ) + singleton_dims + event_shape)
        return

    for kwargs in ({}, {'expand': True}, {'expand': False}):
        with pytest.raises(NotImplementedError):
            reshaped.enumerate_support(**kwargs)
Пример #4
0
def test_mask(batch_dim, event_dim, mask_dim):
    """Check that ``.mask()`` preserves shapes and masks values.

    A Bernoulli with ``batch_dim`` batch dims and ``event_dim`` event dims is
    masked with a checkerboard over its rightmost ``mask_dim`` batch dims.
    Shapes, mean, variance and enumerated support must be unchanged.
    ``log_prob`` and ``score_parts`` must equal the base values passed
    through ``scale_and_mask(..., mask=mask)``.
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    # The mask spans the rightmost ``mask_dim`` batch dims.
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape. Compare the sample against the expected full shape and
    # against a sample from the masked dist; comparing a shape with itself
    # would check nothing.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    assert sample.shape == shape
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample),
                 scale_and_mask(base_dist.log_prob(sample), mask=mask))
    assert_equal(dist.score_parts(sample),
                 base_dist.score_parts(sample).scale_and_mask(mask=mask),
                 prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
        assert_equal(dist.enumerate_support(expand=True),
                     base_dist.enumerate_support(expand=True))
        assert_equal(dist.enumerate_support(expand=False),
                     base_dist.enumerate_support(expand=False))
Пример #5
0
def test_sample_shape_order():
    """Successive ``.expand_by()`` calls must prepend dims on the left.

    Expanding by ``(3, 4)`` and then by ``(1, 2)`` must give the same batch
    and event shapes as one expansion by ``(1, 2, 3, 4)``.
    """
    inner, outer = torch.Size((3, 4)), torch.Size((1, 2))
    base = Bernoulli(0.5)

    chained = base.expand_by(inner).expand_by(outer)
    direct = base.expand_by(outer + inner)
    for attr in ('event_shape', 'batch_shape'):
        assert getattr(chained, attr) == getattr(direct, attr)
Пример #6
0
def test_sample_shape_order():
    """Successive ``.expand_by()`` calls must prepend dims on the left.

    Expanding by ``(3, 4)`` and then by ``(1, 2)`` must give the same batch
    and event shapes as one expansion by ``(1, 2, 3, 4)``.
    """
    left = torch.Size((1, 2))
    right = torch.Size((3, 4))
    scalar_dist = Bernoulli(0.5)

    two_step = scalar_dist.expand_by(right).expand_by(left)
    one_step = scalar_dist.expand_by(left + right)
    assert (two_step.event_shape, two_step.batch_shape) == \
        (one_step.event_shape, one_step.batch_shape)
Пример #7
0
def test_idempotent(batch_dim, event_dim):
    """Expanding by an empty shape must leave batch and event shapes unchanged."""
    shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    batch_shape, event_shape = shape[:batch_dim], shape[batch_dim:]

    # Starting distribution with the requested batch/event split.
    original = Bernoulli(0.5).expand_by(shape).independent(event_dim)
    assert (original.batch_shape, original.event_shape) == \
        (batch_shape, event_shape)

    # .expand_by([]) must be a no-op on shapes.
    expanded = original.expand_by([])
    assert (expanded.batch_shape, expanded.event_shape) == \
        (original.batch_shape, original.event_shape)
Пример #8
0
def test_idempotent(batch_dim, event_dim):
    """Expanding by an empty shape must leave batch and event shapes unchanged."""
    full = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    expected_batch = full[:batch_dim]
    expected_event = full[batch_dim:]

    starting = Bernoulli(0.5).expand_by(full).independent(event_dim)
    assert starting.batch_shape == expected_batch
    assert starting.event_shape == expected_event

    # An empty expansion should not alter either shape.
    unchanged = starting.expand_by([])
    for attr in ('batch_shape', 'event_shape'):
        assert getattr(unchanged, attr) == getattr(starting, attr)
Пример #9
0
def test_mask(batch_dim, event_dim, mask_dim):
    """Check that ``.mask()`` preserves shapes and masks values.

    A Bernoulli with ``batch_dim`` batch dims and ``event_dim`` event dims is
    masked with a checkerboard over its rightmost ``mask_dim`` batch dims.
    Shapes, mean, variance and enumerated support must be unchanged.
    ``log_prob`` and ``score_parts`` must equal the base values multiplied
    by the mask.
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    # The mask spans the rightmost ``mask_dim`` batch dims.
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape. Compare the sample against the expected full shape and
    # against a sample from the masked dist; comparing a shape with itself
    # would check nothing.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    assert sample.shape == shape
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
Пример #10
0
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """``.independent(n)`` must accept n up to the number of non-event dims and reject more.

    Dim sizes are ``0, 1, 2, ...`` so every dim is distinct, which makes a
    wrong batch/event split visible in the shape comparisons.
    """
    shape = torch.Size(range(sample_dim + batch_dim + event_dim))
    non_event = sample_dim + batch_dim
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:non_event]
    event_shape = shape[non_event:]

    # Starting distribution with the requested batch/event split.
    base = Bernoulli(0.5).expand_by(batch_shape + event_shape).independent(event_dim)
    assert base.batch_shape == batch_shape
    assert base.event_shape == event_shape

    # Any count from 0 up to all non-event dims is valid.
    for extra in range(non_event + 1):
        reshaped = base.expand_by(sample_shape).independent(extra)
        boundary = non_event - extra
        assert reshaped.batch_shape == shape[:boundary]
        assert reshaped.event_shape == shape[boundary:]

    # Asking for more dims than exist must raise.
    for extra in range(non_event + 1, 20):
        with pytest.raises(ValueError):
            base.expand_by(sample_shape).independent(extra)
Пример #11
0
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """``.independent(n)`` must accept n up to the number of non-event dims.

    Larger n must raise ``ValueError``. Dim sizes are ``0, 1, 2, ...`` so
    every dim is distinct, which makes a wrong batch/event split visible.
    """
    shape = torch.Size(range(sample_dim + batch_dim + event_dim))
    n_movable = sample_dim + batch_dim
    sample_shape, batch_shape, event_shape = (shape[:sample_dim],
                                              shape[sample_dim:n_movable],
                                              shape[n_movable:])

    # Starting distribution with the requested batch/event split.
    starting = Bernoulli(0.5).expand_by(batch_shape + event_shape)
    starting = starting.independent(event_dim)
    assert starting.batch_shape == batch_shape
    assert starting.event_shape == event_shape

    # Valid: move between 0 and n_movable dims into the event shape.
    for n_extra in range(n_movable + 1):
        cut = n_movable - n_extra
        result = starting.expand_by(sample_shape).independent(n_extra)
        assert result.batch_shape == shape[:cut]
        assert result.event_shape == shape[cut:]

    # Invalid: more dims than the distribution has outside its event.
    for n_extra in range(n_movable + 1, 20):
        with pytest.raises(ValueError):
            starting.expand_by(sample_shape).independent(n_extra)
Пример #12
0
def test_independent_entropy():
    """Entropy of a two-element independent event is twice the scalar entropy.

    Two Bernoulli(0.5) coordinates declared as a single event via
    ``.to_event(1)`` must have entropy ``2 * H(Bernoulli(0.5))``.
    """
    dist_univ = Bernoulli(0.5)
    # Use the ``torch.tensor`` factory rather than the legacy ``torch.Tensor``
    # constructor; both yield a tensor of the default float dtype here.
    dist_multi = Bernoulli(torch.tensor([0.5, 0.5])).to_event(1)
    assert_equal(dist_multi.entropy(), 2 * dist_univ.entropy())
Пример #13
0
def test_mask_invalid_shape(batch_shape, mask_shape):
    """``.mask()`` must raise ``ValueError`` for an incompatible mask shape.

    ``batch_shape`` and ``mask_shape`` are presumably parametrized as
    incompatible pairs; confirm against the test's parametrize decorator.
    """
    bad_mask = checker_mask(mask_shape)
    batched = Bernoulli(0.1).expand_by(batch_shape)
    with pytest.raises(ValueError):
        batched.mask(bad_mask)