def test_expand_new_dim(dist, sample_shape, shape_type, default):
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        if default:
            large = TorchDistribution.expand(small, shape_type(sample_shape + small.batch_shape))
        else:
            with xfail_if_not_implemented():
                large = small.expand(shape_type(sample_shape + small.batch_shape))
        assert large.batch_shape == sample_shape + small.batch_shape
        if dist.get_test_distribution_name() == 'Stable':
            pytest.skip('Stable does not implement a log_prob method.')
        check_sample_shapes(small, large)

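# `check_sample_shapes` is a shared helper defined elsewhere in the test suite.
# A minimal sketch of the invariant it is assumed to verify is given below, under
# a hypothetical name so it does not shadow the real helper: samples drawn from
# either distribution should be scored consistently by both, once log_prob is
# broadcast to the expanded batch shape.
def _check_sample_shapes_sketch(small, large):
    x = small.sample()
    assert torch.allclose(small.log_prob(x).expand(large.batch_shape), large.log_prob(x))
    x = large.sample()
    assert torch.allclose(small.log_prob(x).expand(large.batch_shape), large.log_prob(x))
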
def test_expand_error(dist, initial_shape, proposed_shape, default):
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        if default:
            large = TorchDistribution.expand(small, initial_shape + small.batch_shape)
        else:
            with xfail_if_not_implemented():
                large = small.expand(torch.Size(initial_shape) + small.batch_shape)
        proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
        if dist.get_test_distribution_name() == 'LKJCorrCholesky':
            pytest.skip('LKJCorrCholesky can expand to a shape not '
                        'broadcastable with its original batch_shape.')
        with pytest.raises((RuntimeError, ValueError)):
            large.expand(proposed_batch_shape)

def test_subsequent_expands_ok(dist, sample_shapes, default):
    for idx in range(dist.get_num_test_data()):
        d = dist.pyro_dist(**dist.get_dist_params(idx))
        original_batch_shape = d.batch_shape
        for shape in sample_shapes:
            proposed_batch_shape = torch.Size(shape) + original_batch_shape
            if default:
                n = TorchDistribution.expand(d, proposed_batch_shape)
            else:
                with xfail_if_not_implemented():
                    n = d.expand(proposed_batch_shape)
            assert n.batch_shape == proposed_batch_shape
            with xfail_if_not_implemented():
                check_sample_shapes(d, n)
            d = n

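# `xfail_if_not_implemented` is a context manager imported from the shared test
# utilities. A plausible minimal sketch is shown below, under a hypothetical name
# so it does not shadow the real import: it converts NotImplementedError into an
# expected failure, so distributions that do not yet support expand() are reported
# as xfail rather than errors.
import contextlib

@contextlib.contextmanager
def _xfail_if_not_implemented_sketch(msg="Not implemented"):
    try:
        yield
    except NotImplementedError as e:
        pytest.xfail(reason="{}: {}".format(msg, e))
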
def test_expand_existing_dim(dist, shape_type, default):
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        for dim, size in enumerate(small.batch_shape):
            if size != 1:
                continue
            batch_shape = list(small.batch_shape)
            batch_shape[dim] = 5
            batch_shape = torch.Size(batch_shape)
            if default:
                large = TorchDistribution.expand(small, shape_type(batch_shape))
            else:
                with xfail_if_not_implemented():
                    large = small.expand(shape_type(batch_shape))
            assert large.batch_shape == batch_shape
            if dist.get_test_distribution_name() == 'Stable':
                pytest.skip('Stable does not implement a log_prob method.')
            check_sample_shapes(small, large)

def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default):
    probs = torch.ones(1, 6) / 6
    d = dist.OneHotCategorical(probs)
    full_shape = torch.Size([4, 1, 1, 1, 6])
    if default:
        reshaped_dist = TorchDistribution.expand(d, [4, 1, 1, 1]).to_event(extra_event_dims)
    else:
        reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims)
    cut = 4 - extra_event_dims
    batch_shape, event_shape = full_shape[:cut], full_shape[cut:]
    assert reshaped_dist.batch_shape == batch_shape
    assert reshaped_dist.event_shape == event_shape
    large = reshaped_dist.expand(expand_shape)
    assert large.batch_shape == torch.Size(expand_shape)
    assert large.event_shape == torch.Size(event_shape)
    # Throws an error when the batch shape cannot be broadcast
    with pytest.raises((RuntimeError, ValueError)):
        reshaped_dist.expand(expand_shape + [3])
    # Throws an error when trying to shrink the existing batch shape
    with pytest.raises((RuntimeError, ValueError)):
        large.expand(expand_shape[1:])

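# Illustrative addition (not part of the original suite): a concrete instance of
# the shape arithmetic exercised above. With probs of shape (1, 6),
# OneHotCategorical has batch_shape (1,) and event_shape (6,); expand_by([4, 1, 1])
# prepends those dims to the batch shape, and to_event(k) moves the k rightmost
# batch dims into the event shape, which is why the split point above is
# cut = 4 - extra_event_dims.
def test_reshaped_shape_arithmetic_example():
    d = dist.OneHotCategorical(torch.ones(1, 6) / 6)
    reshaped = d.expand_by([4, 1, 1]).to_event(2)
    assert reshaped.batch_shape == torch.Size([4, 1])
    assert reshaped.event_shape == torch.Size([1, 1, 6])
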
def test_expand_enumerate_support():
    probs = torch.ones(3, 6) / 6
    d = dist.Categorical(probs)
    actual_enum_shape = TorchDistribution.expand(d, (4, 3)).enumerate_support(expand=True).shape
    assert actual_enum_shape == (6, 4, 3)

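# Illustrative addition (not part of the original suite): the basic expand
# semantics the tests above exercise, shown on a simple Normal distribution.
# Expanding broadcasts the batch shape without changing the underlying parameter
# values; shrinking an existing batch shape is rejected.
def test_expand_semantics_example():
    d = dist.Normal(torch.zeros(1, 3), torch.ones(1, 3))  # batch_shape (1, 3)
    e = d.expand(torch.Size([5, 3]))
    assert e.batch_shape == torch.Size([5, 3])
    assert e.sample().shape == torch.Size([5, 3])
    with pytest.raises((RuntimeError, ValueError)):
        e.expand(torch.Size([3]))  # cannot drop or shrink an existing batch dim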