def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Samples from an incrementally built Empirical cover every added value.

    Builds the distribution via repeated ``.add`` calls with five distinct
    constant tensors (0..4), then checks the sample shape and that all five
    values appear among the draws.
    """
    dist = Empirical()
    for value in range(5):
        dist.add(value * torch.ones(batch_shape, dtype=dtype))
    drawn = dist.sample(sample_shape=sample_shape)
    assert_equal(drawn.size(), sample_shape + batch_shape)
    # Flatten and compare as sets: every added value should be represented.
    assert_equal(set(drawn.view(-1).tolist()), set(range(5)))
def test_sample_examples(sample, weights, expected_mean, expected_var):
    """Empirical mean/variance match expectations, and Monte-Carlo moments agree.

    First checks the analytic ``mean``/``variance`` properties, then draws a
    large batch of samples and verifies the empirical moments are close
    (1% relative tolerance).
    """
    dist = Empirical(sample, weights)
    assert_equal(dist.mean, expected_mean)
    assert_equal(dist.variance, expected_var)
    num_samples = 10000
    draws = dist.sample((num_samples,))
    assert_close(draws.mean(0), dist.mean, rtol=1e-2)
    assert_close(draws.var(0), dist.variance, rtol=1e-2)
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Samples from a tensor-constructed Empirical cover every stacked value.

    Same check as the incremental-API variant, but the distribution is built
    in one shot from a stacked tensor of values plus uniform weights.
    """
    stacked = torch.stack(
        [value * torch.ones(batch_shape, dtype=dtype) for value in range(5)]
    )
    dist = Empirical(stacked, torch.ones(5))
    drawn = dist.sample(sample_shape=torch.Size(sample_shape))
    assert_equal(drawn.size(), torch.Size(sample_shape + batch_shape))
    # All five distinct values should show up among the flattened draws.
    assert_equal(set(drawn.view(-1).tolist()), set(range(5)))
def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype):
    """Empirical with batched weights produces samples of the expected shape.

    Constructs samples whose aggregation dimension (size 5) sits immediately
    after the batch dims, pairs them with per-batch uniform weights, and
    checks only the shape of the resulting draws.

    NOTE(review): assumes ``batch_shape``/``event_shape``/``sample_shape``
    are lists (they are concatenated with ``+`` and ``[agg_dim_size]``).
    """
    agg_dim_size = 5
    # Expand arange(5) to (batch + event + [agg]), then permute the trailing
    # aggregation dim into position len(batch_shape).
    n_dims = len(batch_shape + event_shape) + 1  # +1 for the aggregation dim
    order = list(range(n_dims))
    order.insert(len(batch_shape), order.pop())
    emp_samples = (
        torch.arange(agg_dim_size, dtype=dtype)
        .expand(batch_shape + event_shape + [agg_dim_size])
        .permute(order)
    )
    weights = torch.ones(batch_shape + [agg_dim_size])
    dist = Empirical(emp_samples, weights)
    drawn = dist.sample(sample_shape=torch.Size(sample_shape))
    assert_equal(drawn.size(), torch.Size(sample_shape + batch_shape + event_shape))
def test_weighted_sample_coherence(event_shape, dtype):
    """Weighted Empirical (incremental API): log_prob and sampling match weights.

    Adds the value 1.0 with total weight 1.0 and the value 0.0 with total
    weight 3.0, giving probabilities 0.25 and 0.75. Checks ``log_prob`` on
    the ones-tensor and the empirical frequencies over 1000 draws.
    """
    dist = Empirical()
    for value, weight in [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]:
        dist.add(value * torch.ones(event_shape, dtype=dtype), weight=weight)
    assert_equal(dist.event_shape, event_shape)
    assert_equal(dist.sample_size, 4)
    ones = torch.ones(event_shape, dtype=dtype)
    zeros = torch.zeros(event_shape, dtype=dtype)
    assert_equal(dist.log_prob(ones), torch.tensor(0.25).log())
    drawn = dist.sample(sample_shape=torch.Size((1000,)))
    # A draw counts as "all zeros"/"all ones" only if every element matches
    # (min over the flattened event dims acts as a logical AND).
    num_zeros = drawn.eq(zeros).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    num_ones = drawn.eq(ones).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    assert_equal(num_zeros.item() / 1000, 0.75, prec=0.02)
    assert_equal(num_ones.item() / 1000, 0.25, prec=0.02)
def test_weighted_sample_coherence(event_shape, dtype):
    """Weighted Empirical (tensor API): log_prob and sampling match log-weights.

    Builds the distribution from stacked values and stacked log-weights.
    Value 1.0 carries total weight 1.0 and value 0.0 carries 3.0, giving
    probabilities 0.25 and 0.75; checks ``log_prob`` and draw frequencies.
    """
    pairs = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    values = torch.stack(
        [value * torch.ones(event_shape, dtype=dtype) for value, _ in pairs]
    )
    log_weights = torch.stack([torch.tensor(weight).log() for _, weight in pairs])
    dist = Empirical(values, log_weights)
    assert_equal(dist.event_shape, torch.Size(event_shape))
    assert_equal(dist.sample_size, 4)
    ones = torch.ones(event_shape, dtype=dtype)
    zeros = torch.zeros(event_shape, dtype=dtype)
    assert_equal(dist.log_prob(ones), torch.tensor(0.25).log())
    drawn = dist.sample(sample_shape=torch.Size((1000,)))
    # min over flattened event dims acts as a logical AND: a draw matches a
    # reference value only when every element agrees.
    num_zeros = drawn.eq(zeros).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    num_ones = drawn.eq(ones).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    assert_equal(num_zeros.item() / 1000, 0.75, prec=0.02)
    assert_equal(num_ones.item() / 1000, 0.25, prec=0.02)