def test_reshape_reshape(sample_dim, extra_event_dims):
    """Check result shapes after expanding a Bernoulli that already has event dims.

    The starting distribution has batch shape (6, 5) and event shape (4, 3).
    The last ``sample_dim`` entries of (8, 7) are prepended via ``expand_by``,
    then ``extra_event_dims`` is passed to ``.independent()``.
    """
    batch_dim = 2
    batch_shape = torch.Size((6, 5))
    event_shape = torch.Size((4, 3))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape

    # Build the starting distribution and confirm its shapes.
    probs = 0.5 * torch.ones(event_shape)
    dist1 = Bernoulli(probs).expand_by(batch_shape).independent(2)
    assert dist1.event_shape == event_shape
    assert dist1.batch_shape == batch_shape

    # Expand on the left, then convert extra batch dims into event dims.
    dist = dist1.expand_by(sample_shape).independent(extra_event_dims)
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape

    # log_prob keeps only the dims that remain batch dims.
    remaining_batch_dims = sample_dim + batch_dim - extra_event_dims
    assert dist.log_prob(sample).shape == shape[:remaining_batch_dims]

    # Enumeration is expected to work only when there are no event dims.
    if not dist.event_shape:
        # NOTE(review): dist1 already has event shape (4, 3), so this branch
        # looks unreachable here — confirm.
        support = dist.enumerate_support()
        assert support.shape == torch.Size((2,)) + shape
        return
    with pytest.raises(NotImplementedError):
        dist.enumerate_support()
def test_reshape_reshape(sample_dim, extra_event_dims):
    """Check result shapes after expanding a Bernoulli that already has event dims.

    The starting distribution has batch shape (6, 5) and event shape (4, 3).
    The last ``sample_dim`` entries of (8, 7) are prepended via ``expand_by``,
    then ``extra_event_dims`` is passed to ``.independent()``.
    """
    batch_dim = 2
    batch_shape = torch.Size((6, 5))
    event_shape = torch.Size((4, 3))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape

    # Build the starting distribution and confirm its shapes.
    probs = 0.5 * torch.ones(event_shape)
    dist1 = Bernoulli(probs).expand_by(batch_shape).independent(2)
    assert dist1.event_shape == event_shape
    assert dist1.batch_shape == batch_shape

    # Expand on the left, then convert extra batch dims into event dims.
    dist = dist1.expand_by(sample_shape).independent(extra_event_dims)
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape

    # log_prob keeps only the dims that remain batch dims.
    remaining_batch_dims = sample_dim + batch_dim - extra_event_dims
    assert dist.log_prob(sample).shape == shape[:remaining_batch_dims]

    # Enumeration is expected to work only when there are no event dims.
    if not dist.event_shape:
        # NOTE(review): dist1 already has event shape (4, 3), so this branch
        # looks unreachable here — confirm.
        support = dist.enumerate_support()
        assert support.shape == torch.Size((2, )) + shape
        return
    with pytest.raises(NotImplementedError):
        dist.enumerate_support()
def test_reshape(sample_dim, extra_event_dims):
    """Check result shapes after expanding a batch-only Bernoulli.

    The starting distribution has batch shape (5, 4, 3) and an empty event
    shape. The last ``sample_dim`` entries of (8, 7, 6) are prepended via
    ``expand_by``, then ``extra_event_dims`` is passed to ``.to_event()``.
    """
    batch_dim = 3
    batch_shape = torch.Size((5, 4, 3))
    event_shape = torch.Size()
    sample_shape = torch.Size((8, 7, 6))[3 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape

    # Build the starting distribution and confirm its shapes.
    dist0 = Bernoulli(0.5 * torch.ones(batch_shape))
    assert dist0.event_shape == event_shape
    assert dist0.batch_shape == batch_shape

    # Expand on the left, then convert extra batch dims into event dims.
    dist = dist0.expand_by(sample_shape).to_event(extra_event_dims)
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape

    # log_prob keeps only the dims that remain batch dims.
    remaining_batch_dims = sample_dim + batch_dim - extra_event_dims
    assert dist.log_prob(sample).shape == shape[:remaining_batch_dims]

    # With event dims present, every call variant must be rejected.
    if dist.event_shape:
        for kwargs in ({}, {"expand": True}, {"expand": False}):
            with pytest.raises(NotImplementedError):
                dist.enumerate_support(**kwargs)
        return

    # Without event dims, the support gains a leading dim of size 2.
    leading = (2, )
    assert dist.enumerate_support().shape == leading + shape
    assert dist.enumerate_support(expand=True).shape == leading + shape
    # Unexpanded support has size-1 placeholders for sample and batch dims.
    singleton_dims = (1, ) * len(sample_shape + batch_shape)
    unexpanded = dist.enumerate_support(expand=False)
    assert unexpanded.shape == leading + singleton_dims + event_shape
def test_mask(batch_dim, event_dim, mask_dim):
    """Check that ``.mask()`` preserves shapes and masks only log-density terms.

    A Bernoulli is expanded to the first ``batch_dim + event_dim`` sizes of
    [2, 3, 4, 5, 6]; the mask covers the rightmost ``mask_dim`` batch dims.
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # Compare against a sample drawn from the masked distribution; comparing
    # sample.shape with itself would pass unconditionally.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample),
                 scale_and_mask(base_dist.log_prob(sample), mask=mask))
    assert_equal(dist.score_parts(sample),
                 base_dist.score_parts(sample).scale_and_mask(mask=mask),
                 prec=0)

    # Enumeration is only compared when there are no event dims.
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
        assert_equal(dist.enumerate_support(expand=True),
                     base_dist.enumerate_support(expand=True))
        assert_equal(dist.enumerate_support(expand=False),
                     base_dist.enumerate_support(expand=False))
def test_sample_shape_order():
    """Check that successive ``expand_by`` calls add dimensions on the left.

    Expanding by (3, 4) and then by (1, 2) must give the same batch and event
    shapes as a single expansion by (1, 2) + (3, 4).
    """
    outer = torch.Size((1, 2))
    inner = torch.Size((3, 4))
    base = Bernoulli(0.5)

    # Two-step expansion: the later call supplies the leftmost dims.
    stepwise = base.expand_by(inner).expand_by(outer)
    # One-step expansion with the concatenated shape.
    at_once = base.expand_by(outer + inner)

    assert stepwise.event_shape == at_once.event_shape
    assert stepwise.batch_shape == at_once.batch_shape
def test_sample_shape_order():
    """Check that successive ``expand_by`` calls add dimensions on the left.

    Expanding by (3, 4) and then by (1, 2) must give the same batch and event
    shapes as a single expansion by (1, 2) + (3, 4).
    """
    outer = torch.Size((1, 2))
    inner = torch.Size((3, 4))
    base = Bernoulli(0.5)

    # Two-step expansion: the later call supplies the leftmost dims.
    stepwise = base.expand_by(inner).expand_by(outer)
    # One-step expansion with the concatenated shape.
    at_once = base.expand_by(outer + inner)

    assert stepwise.event_shape == at_once.event_shape
    assert stepwise.batch_shape == at_once.batch_shape
def test_idempotent(batch_dim, event_dim):
    """Check that ``expand_by`` with an empty shape leaves shapes unchanged.

    The full shape is the first ``batch_dim + event_dim`` entries of
    (1, 2, 3, 4), split into batch dims on the left and event dims on the right.
    """
    full_shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    expected_batch = full_shape[:batch_dim]
    expected_event = full_shape[batch_dim:]

    # Build the starting distribution and confirm its shapes.
    dist0 = Bernoulli(0.5).expand_by(full_shape).independent(event_dim)
    assert dist0.batch_shape == expected_batch
    assert dist0.event_shape == expected_event

    # Expanding by an empty shape must be a no-op for both shapes.
    dist = dist0.expand_by([])
    assert dist.batch_shape == dist0.batch_shape
    assert dist.event_shape == dist0.event_shape
def test_idempotent(batch_dim, event_dim):
    """Check that ``expand_by`` with an empty shape leaves shapes unchanged.

    The full shape is the first ``batch_dim + event_dim`` entries of
    (1, 2, 3, 4), split into batch dims on the left and event dims on the right.
    """
    full_shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    expected_batch = full_shape[:batch_dim]
    expected_event = full_shape[batch_dim:]

    # Build the starting distribution and confirm its shapes.
    dist0 = Bernoulli(0.5).expand_by(full_shape).independent(event_dim)
    assert dist0.batch_shape == expected_batch
    assert dist0.event_shape == expected_event

    # Expanding by an empty shape must be a no-op for both shapes.
    dist = dist0.expand_by([])
    assert dist.batch_shape == dist0.batch_shape
    assert dist.event_shape == dist0.event_shape
def test_mask(batch_dim, event_dim, mask_dim):
    """Check that ``.mask()`` preserves shapes and masks only log-density terms.

    A Bernoulli is expanded to the first ``batch_dim + event_dim`` sizes of
    [2, 3, 4, 5, 6]; the mask covers the rightmost ``mask_dim`` batch dims.
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # Compare against a sample drawn from the masked distribution; comparing
    # sample.shape with itself would pass unconditionally.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample),
                 base_dist.score_parts(sample) * mask,
                 prec=0)

    # Enumeration is only compared when there are no event dims.
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """Check which ``extra_event_dims`` values ``.independent()`` accepts.

    Values up to ``sample_dim + batch_dim`` must move that many rightmost
    batch dims into the event shape; larger values (below 20) must raise
    ``ValueError``.
    """
    non_event_dims = sample_dim + batch_dim
    shape = torch.Size(range(non_event_dims + event_dim))
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:non_event_dims]
    event_shape = shape[non_event_dims:]

    # Build the starting distribution and confirm its shapes.
    expanded = Bernoulli(0.5).expand_by(batch_shape + event_shape)
    dist0 = expanded.independent(event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # Valid values: the batch/event boundary moves left by extra_event_dims.
    for extra_event_dims in range(1 + non_event_dims):
        dist = dist0.expand_by(sample_shape).independent(extra_event_dims)
        boundary = non_event_dims - extra_event_dims
        assert dist.batch_shape == shape[:boundary]
        assert dist.event_shape == shape[boundary:]

    # Invalid values: more extra event dims than available batch dims.
    for extra_event_dims in range(1 + non_event_dims, 20):
        with pytest.raises(ValueError):
            dist0.expand_by(sample_shape).independent(extra_event_dims)
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    """Check which ``extra_event_dims`` values ``.independent()`` accepts.

    Values up to ``sample_dim + batch_dim`` must move that many rightmost
    batch dims into the event shape; larger values (below 20) must raise
    ``ValueError``.
    """
    non_event_dims = sample_dim + batch_dim
    shape = torch.Size(range(non_event_dims + event_dim))
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:non_event_dims]
    event_shape = shape[non_event_dims:]

    # Build the starting distribution and confirm its shapes.
    expanded = Bernoulli(0.5).expand_by(batch_shape + event_shape)
    dist0 = expanded.independent(event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # Valid values: the batch/event boundary moves left by extra_event_dims.
    for extra_event_dims in range(1 + non_event_dims):
        dist = dist0.expand_by(sample_shape).independent(extra_event_dims)
        boundary = non_event_dims - extra_event_dims
        assert dist.batch_shape == shape[:boundary]
        assert dist.event_shape == shape[boundary:]

    # Invalid values: more extra event dims than available batch dims.
    for extra_event_dims in range(1 + non_event_dims, 20):
        with pytest.raises(ValueError):
            dist0.expand_by(sample_shape).independent(extra_event_dims)
def test_independent_entropy():
    """Check entropy of a two-element Bernoulli event against the scalar case.

    A Bernoulli over two probabilities of 0.5, with one dim declared as an
    event dim, must have twice the entropy of a scalar Bernoulli(0.5).
    """
    scalar_dist = Bernoulli(0.5)
    pair_probs = torch.Tensor([0.5, 0.5])
    pair_dist = Bernoulli(pair_probs).to_event(1)

    actual = pair_dist.entropy()
    expected = 2 * scalar_dist.entropy()
    assert_equal(actual, expected)
def test_mask_invalid_shape(batch_shape, mask_shape):
    """Check that ``.mask()`` raises ``ValueError`` for the given shape pair.

    The distribution is a Bernoulli(0.1) expanded by ``batch_shape``; the mask
    is built by ``checker_mask(mask_shape)``.
    """
    expanded = Bernoulli(0.1).expand_by(batch_shape)
    bad_mask = checker_mask(mask_shape)

    # Only the mask() call itself is expected to raise.
    with pytest.raises(ValueError):
        expanded.mask(bad_mask)