示例#1
0
def test_sample_is_valid(num_leaves):
    """Sampled coalescent times satisfy the constraint and are all distinct.

    Runs once with randomly dispersed leaf times (``torch.randn``) and once
    with all leaves at time zero (``torch.zeros``), in that order.
    """
    pyro.set_rng_seed(num_leaves)

    # Dispersed leaves first, then simultaneous leaves; each case draws its
    # own sample of coalescent times.
    for make_leaf_times in (torch.randn, torch.zeros):
        leaves = make_leaf_times(num_leaves)
        samples = _sample_coalescent_times(leaves)
        assert CoalescentTimesConstraint(leaves).check(samples)
        # No two coalescent events may share the same time.
        assert len(set(samples.tolist())) == len(samples)
示例#2
0
def test_with_rate_smoke(num_leaves, num_steps, leaf_times_shape, rate_grid_shape, sample_shape):
    """Smoke-test shapes of sampling and ``log_prob`` for CoalescentTimesWithRate.

    Leaf times and the rate grid are broadcast into a common batch shape.
    Sampled coalescent times should carry ``num_leaves - 1`` events per batch
    element, and ``log_prob`` should reduce away that event dimension.
    """
    batch_shape = broadcast_shape(leaf_times_shape, rate_grid_shape)
    full_shape = sample_shape + batch_shape

    # sqrt of a uniform draw, scaled into the [0, num_steps) range.
    leaves = torch.rand(leaf_times_shape + (num_leaves,)).pow(0.5) * num_steps
    rates = torch.rand(rate_grid_shape + (num_steps,))
    coalescent = CoalescentTimesWithRate(leaves, rates)

    samples = _sample_coalescent_times(leaves.expand(full_shape + (-1,)))
    assert samples.shape == full_shape + (num_leaves - 1,)

    log_density = coalescent.log_prob(samples)
    assert log_density.shape == full_shape