def test_sample_is_valid(num_leaves):
    """Check that sampled coalescent times are valid and pairwise distinct.

    Two leaf configurations are exercised in turn: leaf times drawn from
    ``torch.randn`` (dispersed leaves) and leaf times that are all zero
    (simultaneous leaves). For each, the sampled coalescent times must pass
    ``CoalescentTimesConstraint(leaf_times).check`` and contain no repeats.
    """
    pyro.set_rng_seed(num_leaves)

    # The dispersed case runs first, then the simultaneous case. This keeps
    # the sequence of RNG draws (randn, sample, sample) the same per seed.
    for make_leaf_times in (torch.randn, torch.zeros):
        leaf_times = make_leaf_times(num_leaves)
        coal_times = _sample_coalescent_times(leaf_times)
        assert CoalescentTimesConstraint(leaf_times).check(coal_times)
        # No two coalescent events may share exactly the same time.
        num_distinct = len(set(coal_times.tolist()))
        assert num_distinct == len(coal_times)
def test_with_rate_smoke(num_leaves, num_steps, leaf_times_shape,
                         rate_grid_shape, sample_shape):
    """Smoke-test sample and ``log_prob`` shapes of ``CoalescentTimesWithRate``.

    Coalescent times are sampled from leaf times expanded to
    ``sample_shape + batch_shape``. Here ``batch_shape`` is the broadcast of
    ``leaf_times_shape`` and ``rate_grid_shape``. The test asserts that the
    samples have one trailing entry per merge (``num_leaves - 1``). It also
    asserts that ``log_prob`` reduces that trailing dimension away.
    """
    batch_shape = broadcast_shape(leaf_times_shape, rate_grid_shape)
    full_shape = sample_shape + batch_shape

    # Leaf times lie in [0, num_steps); the square root skews them upward.
    leaf_times = torch.rand(leaf_times_shape + (num_leaves,)).pow(0.5) * num_steps
    rate_grid = torch.rand(rate_grid_shape + (num_steps,))
    dist = CoalescentTimesWithRate(leaf_times, rate_grid)

    expanded_leaf_times = leaf_times.expand(full_shape + (-1,))
    coal_times = _sample_coalescent_times(expanded_leaf_times)
    assert coal_times.shape == full_shape + (num_leaves - 1,)

    log_density = dist.log_prob(coal_times)
    assert log_density.shape == full_shape