def test_log_prob(batch_shape, event_shape, dtype):
    """Batched ``Empirical.log_prob``: in-support values score log(1/5),
    out-of-support values score -inf, and a query with an extra leading
    (vectorized) dimension raises ``ValueError``.
    """
    # Five distinct atoms 0..4 (each of event_shape), broadcast out to the
    # requested batch shape with uniform weights.
    atoms = torch.stack([i * torch.ones(event_shape, dtype=dtype) for i in range(5)])
    samples = atoms.expand(batch_shape + [5] + event_shape)
    weights = torch.tensor(1.0).expand(batch_shape + [5])
    empirical_dist = Empirical(samples, weights)

    # Value inside the support: uniform weights over 5 atoms -> prob 0.2 per batch element.
    query = torch.tensor(1, dtype=dtype).expand(batch_shape + event_shape)
    log_prob = empirical_dist.log_prob(query)
    assert_equal(log_prob, weights.new_full(batch_shape, 0.2).log())

    # Value outside support returns -Inf
    query = torch.tensor(1, dtype=dtype).expand(batch_shape + event_shape) * 6
    log_prob = empirical_dist.log_prob(query)
    assert log_prob.shape == torch.Size(batch_shape)
    assert torch.isinf(log_prob).all()

    # Vectorized ``log_prob`` raises ValueError
    with pytest.raises(ValueError):
        empirical_dist.log_prob(torch.ones([3] + batch_shape + event_shape, dtype=dtype))
def test_log_prob(batch_shape, dtype):
    """Incrementally built ``Empirical``: in-support values score log(1/5),
    out-of-support values score -inf, and vectorized queries raise
    ``ValueError``.
    """
    empirical_dist = Empirical()
    # Add five equally weighted atoms with values 0..4.
    for value in range(5):
        empirical_dist.add(value * torch.ones(batch_shape, dtype=dtype))

    # Value in the support has probability mass 1/5.
    query = torch.ones(batch_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(query), torch.tensor(0.2).log())

    # Value outside support returns -Inf
    query = 6 * torch.ones(batch_shape, dtype=dtype)
    assert empirical_dist.log_prob(query) == -float("inf")

    # Vectorized ``log_prob`` raises ValueError
    with pytest.raises(ValueError):
        empirical_dist.log_prob(torch.ones((3,) + batch_shape, dtype=dtype))
def test_log_prob(batch_shape, dtype):
    """Check ``Empirical.log_prob`` after five ``add`` calls: a supported
    value scores log(0.2), an unsupported one scores -inf, and a query with
    an extra leading dimension raises ``ValueError``.
    """
    empirical_dist = Empirical()
    for k in range(5):
        # Atoms are the constant tensors 0, 1, 2, 3, 4 of shape batch_shape.
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * k)

    log_prob = empirical_dist.log_prob(torch.ones(batch_shape, dtype=dtype))
    assert_equal(log_prob, torch.tensor(0.2).log())

    # Value outside support returns -Inf
    log_prob = empirical_dist.log_prob(torch.ones(batch_shape, dtype=dtype) * 6)
    assert log_prob == float("-inf")

    # Vectorized ``log_prob`` raises ValueError
    with pytest.raises(ValueError):
        vectorized_query = torch.ones((3,) + batch_shape, dtype=dtype)
        empirical_dist.log_prob(vectorized_query)
def test_weighted_sample_coherence(event_shape, dtype):
    """Weighted ``Empirical`` built via ``add``: ``log_prob`` reflects the
    normalized weights, and empirical sampling frequencies match them
    (0.75 mass on zeros, 0.25 on ones).
    """
    # (value, weight) pairs: total weight 4.0 — 1.0 on value 1, 3.0 on value 0.
    pairs = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    empirical_dist = Empirical()
    for value, weight in pairs:
        empirical_dist.add(value * torch.ones(event_shape, dtype=dtype), weight=weight)
    assert_equal(empirical_dist.event_shape, event_shape)
    assert_equal(empirical_dist.sample_size, 4)

    # P(ones) = (0.5 + 0.5) / 4 = 0.25.
    query = torch.ones(event_shape, dtype=dtype) * 1.0
    assert_equal(empirical_dist.log_prob(query), torch.tensor(0.25).log())

    # Frequencies over 1000 draws should approximate the normalized weights.
    draws = empirical_dist.sample(sample_shape=torch.Size((1000,)))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    # A draw counts as "equal" only if every element matches (min over the
    # flattened event dims).
    num_zeros = draws.eq(zeros).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    num_ones = draws.eq(ones).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()
    assert_equal(num_zeros.item() / 1000, 0.75, prec=0.02)
    assert_equal(num_ones.item() / 1000, 0.25, prec=0.02)
def test_weighted_sample_coherence(event_shape, dtype):
    """Weighted ``Empirical`` built from stacked samples and log-weights:
    ``log_prob`` reflects the normalized weights, and sampling frequencies
    match them (0.75 mass on zeros, 0.25 on ones).
    """
    # (value, weight) pairs: total weight 4.0 — 1.0 on value 1, 3.0 on value 0.
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    values = torch.stack([v * torch.ones(event_shape, dtype=dtype) for v, _ in data])
    # The constructor takes weights in log space.
    log_weights = torch.stack([torch.tensor(w).log() for _, w in data])
    empirical_dist = Empirical(values, log_weights)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, 4)

    # P(ones) = (0.5 + 0.5) / 4 = 0.25.
    query = torch.ones(event_shape, dtype=dtype) * 1.0
    assert_equal(empirical_dist.log_prob(query), torch.tensor(0.25).log())

    draws = empirical_dist.sample(sample_shape=torch.Size((1000,)))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)

    def match_count(target):
        # A draw matches only if every element equals the target (min over
        # the flattened event dims acts as an "all" reduction).
        return draws.eq(target).contiguous().view(1000, -1).min(dim=-1)[0].float().sum()

    assert_equal(match_count(zeros).item() / 1000, 0.75, prec=0.02)
    assert_equal(match_count(ones).item() / 1000, 0.25, prec=0.02)