예제 #1
0
def test_log_prob(batch_shape, event_shape, dtype):
    """Check ``Empirical.log_prob`` for a batched distribution over 5 atoms.

    Every batch element holds the atoms ``0..4`` (each a constant tensor of
    shape ``event_shape``) with identical weights. The test expects:

    * scoring the value 1 gives ``log(1/5)`` per batch element;
    * scoring a value outside the support gives exactly ``-inf``;
    * scoring a tensor with an extra leading dimension raises ``ValueError``.

    Args:
        batch_shape: batch dimensions, as a list (it is concatenated with
            list literals below).
        event_shape: event dimensions, as a list.
        dtype: dtype of the atom values.
    """
    num_atoms = 5
    # Atom i is the constant tensor i with shape ``event_shape``.
    atoms = torch.stack(
        [torch.ones(event_shape, dtype=dtype) * i for i in range(num_atoms)])
    samples = atoms.expand(batch_shape + [num_atoms] + event_shape)
    # The same weight for every atom. Whether the constructor reads these as
    # raw weights or log-weights (presumably log-weights -- TODO confirm), a
    # constant value normalizes to uniform mass 1/num_atoms per atom.
    weights = torch.tensor(1.0).expand(batch_shape + [num_atoms])
    empirical_dist = Empirical(samples, weights)

    sample_to_score = torch.tensor(1, dtype=dtype).expand(batch_shape +
                                                          event_shape)
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert_equal(log_prob,
                 weights.new_full(batch_shape, 1.0 / num_atoms).log())

    # A value outside the support (atoms are 0..4) must score exactly -inf.
    # ``torch.isinf`` alone would also accept +inf, so compare explicitly.
    sample_to_score = torch.tensor(
        1, dtype=dtype).expand(batch_shape + event_shape) * 6
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert log_prob.shape == torch.Size(batch_shape)
    assert (log_prob == -float("inf")).all()

    # Vectorized ``log_prob`` (an extra leading dimension) raises ValueError.
    # Build the input outside ``pytest.raises`` so only the call under test
    # can satisfy the expectation.
    sample_to_score = torch.ones([3] + batch_shape + event_shape, dtype=dtype)
    with pytest.raises(ValueError):
        empirical_dist.log_prob(sample_to_score)
예제 #2
0
def test_log_prob(batch_shape, dtype):
    """Check ``Empirical.log_prob`` on a distribution built with ``add``.

    Five values ``0..4`` are added, each a constant tensor of shape
    ``batch_shape`` and each without an explicit weight. The test expects:

    * scoring 1 gives ``log(0.2)``;
    * scoring a value outside the support gives ``-inf``;
    * scoring a tensor with an extra leading dimension raises ``ValueError``.

    Args:
        batch_shape: shape of each added value, as a tuple (it is
            concatenated with ``(3,)`` below).
        dtype: dtype of the added values.
    """
    empirical_dist = Empirical()
    for i in range(5):
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * i)

    sample_to_score = torch.ones(batch_shape, dtype=dtype)
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert_equal(log_prob, torch.tensor(0.2).log())

    # A value outside the support (values are 0..4) returns -Inf.
    sample_to_score = torch.ones(batch_shape, dtype=dtype) * 6
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert log_prob == -float("inf")

    # Vectorized ``log_prob`` raises ValueError. Build the input first so that
    # only the ``log_prob`` call can satisfy ``pytest.raises``.
    sample_to_score = torch.ones((3,) + batch_shape, dtype=dtype)
    with pytest.raises(ValueError):
        empirical_dist.log_prob(sample_to_score)
예제 #3
0
def test_log_prob(batch_shape, dtype):
    """Check ``Empirical.log_prob`` on a distribution built with ``add``.

    Five values ``0..4`` are added, each a constant tensor of shape
    ``batch_shape`` and each without an explicit weight. The test expects:

    * scoring 1 gives ``log(0.2)``;
    * scoring a value outside the support gives ``-inf``;
    * scoring a tensor with an extra leading dimension raises ``ValueError``.

    Args:
        batch_shape: shape of each added value, as a tuple (it is
            concatenated with ``(3,)`` below).
        dtype: dtype of the added values.
    """
    empirical_dist = Empirical()
    for i in range(5):
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * i)

    sample_to_score = torch.ones(batch_shape, dtype=dtype)
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert_equal(log_prob, torch.tensor(0.2).log())

    # A value outside the support (values are 0..4) returns -Inf.
    sample_to_score = torch.ones(batch_shape, dtype=dtype) * 6
    log_prob = empirical_dist.log_prob(sample_to_score)
    assert log_prob == -float("inf")

    # Vectorized ``log_prob`` raises ValueError. Build the input first so that
    # only the ``log_prob`` call can satisfy ``pytest.raises``.
    sample_to_score = torch.ones((3,) + batch_shape, dtype=dtype)
    with pytest.raises(ValueError):
        empirical_dist.log_prob(sample_to_score)
예제 #4
0
def test_weighted_sample_coherence(event_shape, dtype):
    """Check that weighted ``add`` calls drive both ``log_prob`` and ``sample``.

    The value 1 is added twice with ``weight=0.5`` and the value 0 twice with
    ``weight=1.5``. The test therefore expects:

    * value 1 has probability 0.25 and value 0 has probability 0.75;
    * ``log_prob`` of 1 matches ``log(0.25)``;
    * the empirical frequencies of 1000 draws match both probabilities to
      within ``prec=0.02``.

    Args:
        event_shape: shape of each added value (tuple/list/``torch.Size`` --
            TODO confirm against the fixture).
        dtype: dtype of the added values.
    """
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    empirical_dist = Empirical()
    for value, weight in data:
        empirical_dist.add(value * torch.ones(event_shape, dtype=dtype),
                           weight=weight)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, len(data))

    sample_to_score = torch.ones(event_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(sample_to_score),
                 torch.tensor(0.25).log())

    num_draws = 1000
    draws = empirical_dist.sample(sample_shape=torch.Size((num_draws,)))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    # A draw equals a value only if *every* event element matches, so reduce
    # the flattened boolean comparison with ``all``. ``min`` over a bool
    # tensor expresses the same intent less directly and may not be supported
    # by older torch releases.
    num_zeros = draws.eq(zeros).reshape(num_draws, -1).all(dim=-1).sum().item()
    num_ones = draws.eq(ones).reshape(num_draws, -1).all(dim=-1).sum().item()
    assert_equal(num_zeros / num_draws, 0.75, prec=0.02)
    assert_equal(num_ones / num_draws, 0.25, prec=0.02)
예제 #5
0
def test_weighted_sample_coherence(event_shape, dtype):
    """Check that constructor weights drive both ``log_prob`` and ``sample``.

    The value 1 appears twice with raw weight 0.5 and the value 0 twice with
    raw weight 1.5. The test therefore expects:

    * value 1 has probability 0.25 and value 0 has probability 0.75;
    * ``log_prob`` of 1 matches ``log(0.25)``;
    * the empirical frequencies of 1000 draws match both probabilities to
      within ``prec=0.02``.

    Args:
        event_shape: shape of each value (anything ``torch.Size`` accepts).
        dtype: dtype of the values.
    """
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    values = torch.stack(
        [value * torch.ones(event_shape, dtype=dtype) for value, _ in data])
    # Raw weights are passed through ``log``, presumably because the
    # constructor expects log-weights -- TODO confirm against ``Empirical``.
    log_weights = torch.stack([torch.tensor(weight).log() for _, weight in data])
    empirical_dist = Empirical(values, log_weights)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, len(data))

    sample_to_score = torch.ones(event_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(sample_to_score),
                 torch.tensor(0.25).log())

    num_draws = 1000
    draws = empirical_dist.sample(sample_shape=torch.Size((num_draws,)))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    # A draw equals a value only if *every* event element matches, so reduce
    # the flattened boolean comparison with ``all``. ``min`` over a bool
    # tensor expresses the same intent less directly and may not be supported
    # by older torch releases.
    num_zeros = draws.eq(zeros).reshape(num_draws, -1).all(dim=-1).sum().item()
    num_ones = draws.eq(ones).reshape(num_draws, -1).all(dim=-1).sum().item()
    assert_equal(num_zeros / num_draws, 0.75, prec=0.02)
    assert_equal(num_ones / num_draws, 0.25, prec=0.02)