def test_simple_smoke(num_leaves, num_steps, batch_shape, sample_shape):
    """Sample from and score a CoalescentTimes, checking result shapes.

    Sampled coalescent times carry one fewer entry than there are leaves
    in their rightmost dimension; ``log_prob`` reduces that dimension away.
    """
    leaves = num_steps * torch.rand(batch_shape + (num_leaves,)).pow(0.5)
    dist = CoalescentTimes(leaves)

    samples = dist.sample(sample_shape)
    expected_sample_shape = sample_shape + batch_shape + (num_leaves - 1,)
    assert samples.shape == expected_sample_shape

    scores = dist.log_prob(samples)
    assert scores.shape == sample_shape + batch_shape
def test_log_prob_unit_rate(num_leaves, num_steps, batch_shape, sample_shape):
    """A rate grid of all ones must reproduce the plain CoalescentTimes density."""
    leaves = num_steps * torch.rand(batch_shape + (num_leaves,)).pow(0.5)
    plain = CoalescentTimes(leaves)

    ones_grid = torch.ones(batch_shape + (num_steps,))
    gridded = CoalescentTimesWithRate(leaves, ones_grid)

    # Score the same draw under both parameterizations.
    draw = plain.sample(sample_shape)
    reference = plain.log_prob(draw)
    candidate = gridded.log_prob(draw)
    assert_close(reference, candidate)
def test_log_prob_constant_rate_2(num_leaves, num_steps, batch_shape, sample_shape):
    """A scalar rate and the equivalent constant rate grid must agree on log_prob."""
    # Strictly positive per-batch rate (exp of a standard normal draw).
    scalar_rate = torch.randn(batch_shape).exp()
    # Repeat the per-batch rate across every step of the grid.
    constant_grid = scalar_rate[..., None].expand(batch_shape + (num_steps,))

    leaves = num_steps * torch.rand(batch_shape + (num_leaves,)).pow(0.5)
    with_scalar = CoalescentTimes(leaves, scalar_rate)
    draw = with_scalar.sample(sample_shape)
    reference = with_scalar.log_prob(draw)

    with_grid = CoalescentTimesWithRate(leaves, constant_grid)
    candidate = with_grid.log_prob(draw)

    assert_close(reference, candidate)
def test_log_prob_constant_rate(num_leaves, num_steps, batch_shape, sample_shape):
    """Relate a constant-rate grid to a time-rescaled CoalescentTimes.

    Times under the unit-rate model are the rate-model times multiplied by
    ``rate``; the two log densities must therefore differ by the log
    absolute Jacobian determinant of that elementwise rescaling.
    """
    rate = torch.randn(batch_shape).exp()
    grid = rate.unsqueeze(-1).expand(batch_shape + (num_steps,))
    scale = rate.unsqueeze(-1)

    rate_leaves = num_steps * torch.rand(batch_shape + (num_leaves,)).pow(0.5)
    unit_leaves = rate_leaves * scale

    unit_model = CoalescentTimes(unit_leaves)
    unit_times = unit_model.sample(sample_shape)
    unit_score = unit_model.log_prob(unit_times)

    rate_model = CoalescentTimesWithRate(rate_leaves, grid)
    rate_times = unit_times / scale
    rate_score = rate_model.log_prob(rate_times)

    # Dividing each of the size(-1) coordinates by ``rate`` contributes
    # log(1 / rate) per coordinate to log|det J|.
    num_coords = rate_times.size(-1)
    log_abs_det_jacobian = -num_coords * rate.log()
    assert_close(unit_score - log_abs_det_jacobian, rate_score)
def test_likelihood_vectorized(num_leaves, num_steps, batch_shape, clamped):
    """Summed CoalescentRateLikelihood terms must equal CoalescentTimesWithRate.log_prob.

    With ``clamped`` the leaves lie in [0, num_steps] and sampled coalescent
    times are clamped at zero; otherwise leaves are drawn from a normal
    distribution, so some can fall outside [0, num_steps].
    """
    leaf_shape = batch_shape + (num_leaves,)
    if clamped:
        leaf_times = num_steps * torch.rand(leaf_shape).pow(0.5)
    else:
        leaf_times = (torch.randn(leaf_shape) * 0.25 + 0.75) * num_steps

    coal_times = CoalescentTimes(leaf_times).sample()
    if clamped:
        coal_times = coal_times.clamp(min=0)

    # Rates drawn uniformly from [0.5, 1.5).
    grid = 0.5 + torch.rand(batch_shape + (num_steps,))

    reference = CoalescentTimesWithRate(leaf_times, grid).log_prob(coal_times)
    likelihood_fn = CoalescentRateLikelihood(leaf_times, coal_times, num_steps)
    # Reduce the likelihood output over its last dimension before comparing.
    total = likelihood_fn(grid).sum(-1)
    assert_close(total, reference)