def test_mask(batch_dim, event_dim, mask_dim):
    """Check that ``.mask()`` keeps shapes/moments and scales log-probs by the mask.

    A ``Bernoulli(0.1)`` base distribution is expanded to a shape with
    ``batch_dim`` batch dims followed by ``event_dim`` dims that are
    reinterpreted as event dims. It is then masked with a mask spanning the
    rightmost ``mask_dim`` batch dims. The masked distribution must agree with
    the base distribution on batch/event shape, sample shape, mean, variance,
    and (for scalar events) enumerated support. Its ``log_prob`` and
    ``score_parts`` must equal the base values multiplied by the mask.

    Args:
        batch_dim: number of batch dimensions. ``batch_dim + event_dim`` is
            assumed to be <= 5, since the shape is sliced from a 5-element list.
        event_dim: number of trailing dims reinterpreted as event dims.
        mask_dim: number of rightmost batch dims covered by the mask
            (assumed <= ``batch_dim``).
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    # The mask only spans a suffix of the batch shape, so it must broadcast
    # against the full batch shape.
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # Masking must not affect sampling, so a draw from the masked
    # distribution has the same shape as a draw from the base distribution.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)
    # Enumeration is only compared when there are no event dims.
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
def test_mask_invalid_shape(batch_shape, mask_shape):
    """Masking with the given ``mask_shape`` must raise ``ValueError``.

    A ``Bernoulli(0.1)`` distribution is expanded to ``batch_shape``. Calling
    ``.mask()`` on it with a ``checker_mask`` of ``mask_shape`` is expected to
    fail. Presumably the parametrization supplies mask shapes that are
    incompatible with ``batch_shape``; confirm against the decorator.
    """
    expanded = Bernoulli(0.1).expand_by(batch_shape)
    bad_mask = checker_mask(mask_shape)
    pytest.raises(ValueError, expanded.mask, bad_mask)