コード例 #1
0
def test_sample_shape_order():
    """Check that chained ``expand_by`` calls equal one call with the joined shape.

    Expanding by ``(3, 4)`` and then by ``(1, 2)`` is compared against a
    single expansion by ``(1, 2) + (3, 4)``.
    """
    outer = torch.Size((1, 2))
    inner = torch.Size((3, 4))
    base = Bernoulli(0.5)

    # Each .expand_by(...) is expected to prepend its dimensions on the left,
    # so the later call (outer) ends up leftmost.
    chained = base.expand_by(inner).expand_by(outer)
    combined = base.expand_by(outer + inner)

    assert chained.event_shape == combined.event_shape
    assert chained.batch_shape == combined.batch_shape
コード例 #2
0
ファイル: test_reshape.py プロジェクト: lewisKit/pyro
def test_sample_shape_order():
    """Verify the ordering of dimensions added by successive ``expand_by`` calls."""
    base = Bernoulli(0.5)
    left_shape = torch.Size((1, 2))
    right_shape = torch.Size((3, 4))

    # .expand_by(...) should add dimensions on the left, hence applying
    # right_shape first and left_shape second must match left_shape + right_shape.
    step_by_step = base.expand_by(right_shape)
    step_by_step = step_by_step.expand_by(left_shape)
    all_at_once = base.expand_by(left_shape + right_shape)

    for attr in ("event_shape", "batch_shape"):
        assert getattr(step_by_step, attr) == getattr(all_at_once, attr)
コード例 #3
0
ファイル: test_reshape.py プロジェクト: lewisKit/pyro
def test_reshape_reshape(sample_dim, extra_event_dims):
    """Reshape an already-reshaped Bernoulli and check every resulting shape.

    ``sample_dim`` selects how many trailing entries of ``(8, 7)`` are used as
    the sample shape; ``extra_event_dims`` is the number of dims passed to the
    final ``.independent(...)`` call.
    """
    batch_dim = 2
    event_shape = torch.Size((4, 3))
    batch_shape = torch.Size((6, 5))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    full_shape = sample_shape + batch_shape + event_shape

    # Build the starting distribution: event-shaped probs, expanded by the
    # batch shape, with two dims declared as event dims.
    base = Bernoulli(0.5 * torch.ones(event_shape))
    reshaped_once = base.expand_by(batch_shape).independent(2)
    assert reshaped_once.event_shape == event_shape
    assert reshaped_once.batch_shape == batch_shape

    # Reshape a second time and confirm the final shapes.
    dist = reshaped_once.expand_by(sample_shape).independent(extra_event_dims)
    draw = dist.sample()
    assert draw.shape == full_shape
    assert dist.mean.shape == full_shape
    assert dist.variance.shape == full_shape
    n_log_prob_dims = sample_dim + batch_dim - extra_event_dims
    assert dist.log_prob(draw).shape == full_shape[:n_log_prob_dims]

    # Enumeration is only expected to work when there are no event dims.
    if not dist.event_shape:
        assert dist.enumerate_support().shape == torch.Size((2,)) + full_shape
    else:
        with pytest.raises(NotImplementedError):
            dist.enumerate_support()
コード例 #4
0
def test_reshape_reshape(sample_dim, extra_event_dims):
    """Check shapes after applying ``expand_by``/``independent`` twice in a row."""
    batch_dim = 2
    batch_shape = torch.Size((6, 5))
    event_shape = torch.Size((4, 3))
    # Keep only the last `sample_dim` entries of (8, 7) as the sample shape.
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    expected_shape = sample_shape + batch_shape + event_shape

    # First reshape: reach the desired starting batch/event split.
    probs = 0.5 * torch.ones(event_shape)
    start = Bernoulli(probs).expand_by(batch_shape).independent(2)
    assert start.event_shape == event_shape
    assert start.batch_shape == batch_shape

    # Second reshape: this is the distribution under test.
    final = start.expand_by(sample_shape).independent(extra_event_dims)
    value = final.sample()
    assert value.shape == expected_shape
    assert final.mean.shape == expected_shape
    assert final.variance.shape == expected_shape
    log_prob_shape = expected_shape[:sample_dim + batch_dim - extra_event_dims]
    assert final.log_prob(value).shape == log_prob_shape

    # Enumerating support must raise once any event dims are present.
    has_event_dims = bool(final.event_shape)
    if has_event_dims:
        with pytest.raises(NotImplementedError):
            final.enumerate_support()
    else:
        support = final.enumerate_support()
        assert support.shape == torch.Size((2, )) + expected_shape
コード例 #5
0
def test_reshape(sample_dim, extra_event_dims):
    """Check shapes of a batched Bernoulli after ``expand_by`` and ``to_event``.

    ``sample_dim`` selects how many trailing entries of ``(8, 7, 6)`` form the
    sample shape; ``extra_event_dims`` is the argument given to ``to_event``.
    """
    batch_dim = 3
    batch_shape = torch.Size((5, 4, 3))
    event_shape = torch.Size()
    sample_shape = torch.Size((8, 7, 6))[3 - sample_dim:]
    full_shape = sample_shape + batch_shape + event_shape

    # The base distribution starts with a pure batch shape and no event dims.
    base = Bernoulli(0.5 * torch.ones(batch_shape))
    assert base.event_shape == event_shape
    assert base.batch_shape == batch_shape

    # Reshape and confirm the final shapes.
    dist = base.expand_by(sample_shape).to_event(extra_event_dims)
    draw = dist.sample()
    assert draw.shape == full_shape
    assert dist.mean.shape == full_shape
    assert dist.variance.shape == full_shape
    n_log_prob_dims = sample_dim + batch_dim - extra_event_dims
    assert dist.log_prob(draw).shape == full_shape[:n_log_prob_dims]

    # Enumeration: expected to fail with event dims, otherwise to return a
    # leading axis of size 2 in both the expanded and unexpanded forms.
    if not dist.event_shape:
        expanded_shape = (2, ) + full_shape
        assert dist.enumerate_support().shape == expanded_shape
        assert dist.enumerate_support(expand=True).shape == (2, ) + full_shape
        unexpanded_shape = (2, ) + (1, ) * len(sample_shape + batch_shape) + event_shape
        assert dist.enumerate_support(expand=False).shape == unexpanded_shape
    else:
        with pytest.raises(NotImplementedError):
            dist.enumerate_support()
        with pytest.raises(NotImplementedError):
            dist.enumerate_support(expand=True)
        with pytest.raises(NotImplementedError):
            dist.enumerate_support(expand=False)
コード例 #6
0
ファイル: test_reshape.py プロジェクト: lewisKit/pyro
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """Exercise ``.independent(n)`` for both acceptable and too-large ``n``.

    Values of ``n`` up to ``sample_dim + batch_dim`` must move that many
    rightmost batch dims into the event shape; larger values (below 20) must
    raise ``ValueError``.
    """
    n_batch_like = sample_dim + batch_dim
    shape = torch.Size(range(n_batch_like + event_dim))
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:n_batch_like]
    event_shape = shape[n_batch_like:]

    # Construct a base dist of desired starting shape.
    base = Bernoulli(0.5).expand_by(batch_shape + event_shape)
    dist0 = base.independent(event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # Valid values: the batch/event boundary moves left by extra_event_dims.
    for extra_event_dims in range(1 + n_batch_like):
        dist = dist0.expand_by(sample_shape).independent(extra_event_dims)
        boundary = n_batch_like - extra_event_dims
        assert dist.batch_shape == shape[:boundary]
        assert dist.event_shape == shape[boundary:]

    # Invalid values: more extra event dims than available batch dims.
    for extra_event_dims in range(1 + n_batch_like, 20):
        with pytest.raises(ValueError):
            dist0.expand_by(sample_shape).independent(extra_event_dims)
コード例 #7
0
ファイル: test_reshape.py プロジェクト: lewisKit/pyro
def test_idempotent(batch_dim, event_dim):
    """Check that ``expand_by([])`` leaves batch and event shapes unchanged."""
    total_dims = batch_dim + event_dim
    shape = torch.Size((1, 2, 3, 4))[:total_dims]
    batch_shape, event_shape = shape[:batch_dim], shape[batch_dim:]

    # Build the starting distribution with the requested batch/event split.
    original = Bernoulli(0.5).expand_by(shape).independent(event_dim)
    assert original.batch_shape == batch_shape
    assert original.event_shape == event_shape

    # Expanding by an empty shape should be a no-op on both shapes.
    unchanged = original.expand_by([])
    assert unchanged.batch_shape == original.batch_shape
    assert unchanged.event_shape == original.event_shape
コード例 #8
0
def test_idempotent(batch_dim, event_dim):
    """Expanding a distribution by an empty shape must not alter its shapes."""
    shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    expected_batch = shape[:batch_dim]
    expected_event = shape[batch_dim:]

    # Starting point: a scalar Bernoulli expanded to `shape`, with the last
    # `event_dim` dims declared as event dims.
    expanded = Bernoulli(0.5).expand_by(shape)
    before = expanded.independent(event_dim)
    assert before.batch_shape == expected_batch
    assert before.event_shape == expected_event

    # An .expand_by() with no dimensions is expected to be a no-op.
    after = before.expand_by([])
    for attr in ("batch_shape", "event_shape"):
        assert getattr(after, attr) == getattr(before, attr)
コード例 #9
0
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """Check ``.independent(...)`` with in-range and out-of-range dim counts."""
    shape = torch.Size(range(sample_dim + batch_dim + event_dim))
    event_start = sample_dim + batch_dim
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:event_start]
    event_shape = shape[event_start:]

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(0.5).expand_by(batch_shape + event_shape).independent(event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # In-range counts: 0 through sample_dim + batch_dim inclusive.
    max_valid = sample_dim + batch_dim
    for extra_event_dims in range(max_valid + 1):
        dist = dist0.expand_by(sample_shape).independent(extra_event_dims)
        expected_batch = shape[:sample_dim + batch_dim - extra_event_dims]
        expected_event = shape[sample_dim + batch_dim - extra_event_dims:]
        assert dist.batch_shape == expected_batch
        assert dist.event_shape == expected_event

    # Out-of-range counts: everything above max_valid, up to 19.
    for extra_event_dims in range(max_valid + 1, 20):
        with pytest.raises(ValueError):
            dist0.expand_by(sample_shape).independent(extra_event_dims)