Example #1
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Sample from an ``Empirical`` filled via ``add`` without explicit weights.

    Five constant tensors of shape ``batch_shape`` (filled with 0, 1, ..., 4)
    are added, then ``sample_shape`` draws are taken. The draws must have
    shape ``sample_shape + batch_shape`` and contain only added values.
    """
    num_values = 5
    empirical_dist = Empirical()
    for i in range(num_values):
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * i)
    sample_shape = torch.Size(sample_shape)
    samples = empirical_dist.sample(sample_shape=sample_shape)
    # Compare torch.Size against torch.Size so the check does not depend on
    # whether the shapes were supplied as lists or tuples.
    assert_equal(samples.size(), sample_shape + torch.Size(batch_shape))
    # Draws are random, so only membership is guaranteed. Requiring all five
    # values to appear fails whenever a small draw misses one (~6% of the time
    # with 20 draws, always with a single draw).
    drawn = set(samples.reshape(-1).tolist())
    assert drawn <= set(range(num_values))
Example #2
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Sample from an ``Empirical`` filled via ``add`` without explicit weights.

    Five constant tensors of shape ``batch_shape`` (filled with 0, 1, ..., 4)
    are added, then ``sample_shape`` draws are taken. The draws must have
    shape ``sample_shape + batch_shape`` and contain only added values.
    """
    num_values = 5
    empirical_dist = Empirical()
    for i in range(num_values):
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * i)
    sample_shape = torch.Size(sample_shape)
    samples = empirical_dist.sample(sample_shape=sample_shape)
    # Compare torch.Size against torch.Size so the check does not depend on
    # whether the shapes were supplied as lists or tuples.
    assert_equal(samples.size(), sample_shape + torch.Size(batch_shape))
    # Draws are random, so only membership is guaranteed. Requiring all five
    # values to appear fails whenever a small draw misses one (~6% of the time
    # with 20 draws, always with a single draw).
    drawn = set(samples.reshape(-1).tolist())
    assert drawn <= set(range(num_values))
Example #3
0
def test_sample_examples(sample, weights, expected_mean, expected_var):
    """Check exact and Monte Carlo moments of an ``Empirical`` distribution.

    The distribution's ``mean`` and ``variance`` must equal the expected
    tensors, and the mean/variance of 10000 draws must agree with them to
    within a 1% relative tolerance.
    """
    dist = Empirical(sample, weights)

    # Moments reported by the distribution itself.
    assert_equal(dist.mean, expected_mean)
    assert_equal(dist.variance, expected_var)

    # Monte Carlo estimates from a large batch of draws.
    # NOTE(review): 1% relative tolerance is tight for 10000 draws when a
    # moment is small relative to the spread; confirm the RNG seed is fixed.
    n_draws = 10000
    draws = dist.sample((n_draws,))
    sample_mean = draws.mean(0)
    assert_close(sample_mean, dist.mean, rtol=1e-2)
    sample_var = draws.var(0)
    assert_close(sample_var, dist.variance, rtol=1e-2)
Example #4
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Sample from a uniformly weighted ``Empirical`` built from stacked values.

    Five constant tensors of shape ``batch_shape`` (filled with 0, 1, ..., 4)
    are stacked along a new leading dim and given all-ones weights. Drawing
    ``sample_shape`` samples must give shape ``sample_shape + batch_shape``,
    and every drawn element must be one of the stacked values.
    """
    num_values = 5
    support = torch.stack(
        [torch.ones(batch_shape, dtype=dtype) * i for i in range(num_values)]
    )
    empirical_dist = Empirical(support, torch.ones(num_values))
    draws = empirical_dist.sample(sample_shape=torch.Size(sample_shape))
    # torch.Size concatenation works for list or tuple inputs alike.
    assert_equal(draws.size(), torch.Size(sample_shape) + torch.Size(batch_shape))
    # Draws are random, so only membership is guaranteed. Requiring all five
    # values to appear fails whenever a small draw misses one.
    drawn = set(draws.reshape(-1).tolist())
    assert drawn <= set(range(num_values))
Example #5
0
def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype):
    """Check the shape of draws from a uniformly weighted batched ``Empirical``.

    The support tensor has shape ``batch_shape + [5] + event_shape``. The
    aggregation dim of size 5 sits right after the batch dims, and its i-th
    slice is filled with the value ``i``. Weights are all ones with shape
    ``batch_shape + [5]``. Drawing ``sample_shape`` samples must give shape
    ``sample_shape + batch_shape + event_shape``.
    """
    agg_dim_size = 5
    num_batch_dims = len(batch_shape)
    num_dims = len(batch_shape + event_shape)
    # Move the trailing aggregation dim (index ``num_dims``) so it sits just
    # after the batch dims; all other dims keep their relative order.
    agg_after_batch = (
        list(range(num_batch_dims))
        + [num_dims]
        + list(range(num_batch_dims, num_dims))
    )
    values = torch.arange(agg_dim_size, dtype=dtype)
    broadcast = values.expand(batch_shape + event_shape + [agg_dim_size])
    emp_samples = broadcast.permute(agg_after_batch)
    weights = torch.ones(batch_shape + [agg_dim_size])
    empirical_dist = Empirical(emp_samples, weights)
    draws = empirical_dist.sample(sample_shape=torch.Size(sample_shape))
    expected_shape = torch.Size(sample_shape + batch_shape + event_shape)
    assert_equal(draws.size(), expected_shape)
Example #6
0
def test_weighted_sample_coherence(event_shape, dtype):
    """Check that ``log_prob`` and sampling agree for a weighted ``Empirical``.

    Value 1 is added twice with weight 0.5 and value 0 twice with weight 1.5,
    so the test expects value 1 to carry 1/4 of the mass and value 0 to carry
    3/4, both in ``log_prob`` and in the empirical frequency of draws.
    """
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    # With 10000 draws the standard error of a 0.25/0.75 frequency is ~0.004,
    # so the 0.02 tolerance below is ~4.6 SE. With 1000 draws it was only ~1.5 SE.
    num_samples = 10000
    empirical_dist = Empirical()
    for value, weight in data:
        # torch.full keeps the requested dtype. ``value * torch.ones(...)``
        # would promote integer dtypes to float because ``value`` is a float.
        empirical_dist.add(torch.full(event_shape, value, dtype=dtype), weight=weight)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, len(data))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(ones), torch.tensor(0.25).log())
    samples = empirical_dist.sample(sample_shape=torch.Size((num_samples,)))
    # A draw counts as a value only if every element of its event matches.
    num_zeros = samples.eq(zeros).reshape(num_samples, -1).all(dim=-1).sum().item()
    num_ones = samples.eq(ones).reshape(num_samples, -1).all(dim=-1).sum().item()
    assert_equal(num_zeros / num_samples, 0.75, prec=0.02)
    assert_equal(num_ones / num_samples, 0.25, prec=0.02)
Example #7
0
def test_weighted_sample_coherence(event_shape, dtype):
    """Check that ``log_prob`` and sampling agree for a weighted ``Empirical``.

    Value 1 appears twice with raw weight 0.5 and value 0 twice with raw
    weight 1.5, so the test expects value 1 to carry 1/4 of the mass and
    value 0 to carry 3/4, both in ``log_prob`` and in the frequency of draws.
    """
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    # With 10000 draws the standard error of a 0.25/0.75 frequency is ~0.004,
    # so the 0.02 tolerance below is ~4.6 SE. With 1000 draws it was only ~1.5 SE.
    num_samples = 10000
    # torch.full keeps the requested dtype. ``value * torch.ones(...)`` would
    # promote integer dtypes to float because ``value`` is a Python float.
    support = torch.stack(
        [torch.full(event_shape, value, dtype=dtype) for value, _ in data]
    )
    # Raw weights are passed through log, so this constructor presumably
    # expects log-weights.
    log_weights = torch.tensor([weight for _, weight in data]).log()
    empirical_dist = Empirical(support, log_weights)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, len(data))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(ones), torch.tensor(0.25).log())
    draws = empirical_dist.sample(sample_shape=torch.Size((num_samples,)))
    # A draw counts as a value only if every element of its event matches.
    num_zeros = draws.eq(zeros).reshape(num_samples, -1).all(dim=-1).sum().item()
    num_ones = draws.eq(ones).reshape(num_samples, -1).all(dim=-1).sum().item()
    assert_equal(num_zeros / num_samples, 0.75, prec=0.02)
    assert_equal(num_ones / num_samples, 0.25, prec=0.02)